Jacob McCann
8 years ago
308 changed files with 132774 additions and 0 deletions
@ -0,0 +1,66 @@ |
|||
# 0.10.0 |
|||
|
|||
* feature: Add a test hook (#180) |
|||
* feature: `ParseLevel` is now case-insensitive (#326) |
|||
* feature: `FieldLogger` interface that generalizes `Logger` and `Entry` (#308) |
|||
* performance: avoid re-allocations on `WithFields` (#335) |
|||
|
|||
# 0.9.0 |
|||
|
|||
* logrus/text_formatter: don't emit empty msg |
|||
* logrus/hooks/airbrake: move out of main repository |
|||
* logrus/hooks/sentry: move out of main repository |
|||
* logrus/hooks/papertrail: move out of main repository |
|||
* logrus/hooks/bugsnag: move out of main repository |
|||
* logrus/core: run tests with `-race` |
|||
* logrus/core: detect TTY based on `stderr` |
|||
* logrus/core: support `WithError` on logger |
|||
* logrus/core: Solaris support |
|||
|
|||
# 0.8.7 |
|||
|
|||
* logrus/core: fix possible race (#216) |
|||
* logrus/doc: small typo fixes and doc improvements |
|||
|
|||
|
|||
# 0.8.6 |
|||
|
|||
* hooks/raven: allow passing an initialized client |
|||
|
|||
# 0.8.5 |
|||
|
|||
* logrus/core: revert #208 |
|||
|
|||
# 0.8.4 |
|||
|
|||
* formatter/text: fix data race (#218) |
|||
|
|||
# 0.8.3 |
|||
|
|||
* logrus/core: fix entry log level (#208) |
|||
* logrus/core: improve performance of text formatter by 40% |
|||
* logrus/core: expose `LevelHooks` type |
|||
* logrus/core: add support for DragonflyBSD and NetBSD |
|||
* formatter/text: print structs more verbosely |
|||
|
|||
# 0.8.2 |
|||
|
|||
* logrus: fix more Fatal family functions |
|||
|
|||
# 0.8.1 |
|||
|
|||
* logrus: fix not exiting on `Fatalf` and `Fatalln` |
|||
|
|||
# 0.8.0 |
|||
|
|||
* logrus: defaults to stderr instead of stdout |
|||
* hooks/sentry: add special field for `*http.Request` |
|||
* formatter/text: ignore Windows for colors |
|||
|
|||
# 0.7.3 |
|||
|
|||
* formatter/\*: allow configuration of timestamp layout |
|||
|
|||
# 0.7.2 |
|||
|
|||
* formatter/text: Add configuration option for time format (#158) |
@ -0,0 +1,21 @@ |
|||
The MIT License (MIT) |
|||
|
|||
Copyright (c) 2014 Simon Eskildsen |
|||
|
|||
Permission is hereby granted, free of charge, to any person obtaining a copy |
|||
of this software and associated documentation files (the "Software"), to deal |
|||
in the Software without restriction, including without limitation the rights |
|||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|||
copies of the Software, and to permit persons to whom the Software is |
|||
furnished to do so, subject to the following conditions: |
|||
|
|||
The above copyright notice and this permission notice shall be included in |
|||
all copies or substantial portions of the Software. |
|||
|
|||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
|||
THE SOFTWARE. |
@ -0,0 +1,421 @@ |
|||
# Logrus <img src="http://i.imgur.com/hTeVwmJ.png" width="40" height="40" alt=":walrus:" class="emoji" title=":walrus:"/> [![Build Status](https://travis-ci.org/Sirupsen/logrus.svg?branch=master)](https://travis-ci.org/Sirupsen/logrus) [![GoDoc](https://godoc.org/github.com/Sirupsen/logrus?status.svg)](https://godoc.org/github.com/Sirupsen/logrus) |
|||
|
|||
Logrus is a structured logger for Go (golang), completely API compatible with |
|||
the standard library logger. [Godoc][godoc]. **Please note the Logrus API is not |
|||
yet stable (pre 1.0). Logrus itself is completely stable and has been used in |
|||
many large deployments. The core API is unlikely to change much but please |
|||
version control your Logrus to make sure you aren't fetching latest `master` on |
|||
every build.** |
|||
|
|||
Nicely color-coded in development (when a TTY is attached, otherwise just |
|||
plain text): |
|||
|
|||
![Colored](http://i.imgur.com/PY7qMwd.png) |
|||
|
|||
With `log.SetFormatter(&log.JSONFormatter{})`, for easy parsing by logstash |
|||
or Splunk: |
|||
|
|||
```json |
|||
{"animal":"walrus","level":"info","msg":"A group of walrus emerges from the |
|||
ocean","size":10,"time":"2014-03-10 19:57:38.562264131 -0400 EDT"} |
|||
|
|||
{"level":"warning","msg":"The group's number increased tremendously!", |
|||
"number":122,"omg":true,"time":"2014-03-10 19:57:38.562471297 -0400 EDT"} |
|||
|
|||
{"animal":"walrus","level":"info","msg":"A giant walrus appears!", |
|||
"size":10,"time":"2014-03-10 19:57:38.562500591 -0400 EDT"} |
|||
|
|||
{"animal":"walrus","level":"info","msg":"Tremendously sized cow enters the ocean.", |
|||
"size":9,"time":"2014-03-10 19:57:38.562527896 -0400 EDT"} |
|||
|
|||
{"level":"fatal","msg":"The ice breaks!","number":100,"omg":true, |
|||
"time":"2014-03-10 19:57:38.562543128 -0400 EDT"} |
|||
``` |
|||
|
|||
With the default `log.SetFormatter(&log.TextFormatter{})` when a TTY is not |
|||
attached, the output is compatible with the |
|||
[logfmt](http://godoc.org/github.com/kr/logfmt) format: |
|||
|
|||
```text |
|||
time="2015-03-26T01:27:38-04:00" level=debug msg="Started observing beach" animal=walrus number=8 |
|||
time="2015-03-26T01:27:38-04:00" level=info msg="A group of walrus emerges from the ocean" animal=walrus size=10 |
|||
time="2015-03-26T01:27:38-04:00" level=warning msg="The group's number increased tremendously!" number=122 omg=true |
|||
time="2015-03-26T01:27:38-04:00" level=debug msg="Temperature changes" temperature=-4 |
|||
time="2015-03-26T01:27:38-04:00" level=panic msg="It's over 9000!" animal=orca size=9009 |
|||
time="2015-03-26T01:27:38-04:00" level=fatal msg="The ice breaks!" err=&{0x2082280c0 map[animal:orca size:9009] 2015-03-26 01:27:38.441574009 -0400 EDT panic It's over 9000!} number=100 omg=true |
|||
exit status 1 |
|||
``` |
|||
|
|||
#### Example |
|||
|
|||
The simplest way to use Logrus is simply the package-level exported logger: |
|||
|
|||
```go |
|||
package main |
|||
|
|||
import ( |
|||
log "github.com/Sirupsen/logrus" |
|||
) |
|||
|
|||
func main() { |
|||
log.WithFields(log.Fields{ |
|||
"animal": "walrus", |
|||
}).Info("A walrus appears") |
|||
} |
|||
``` |
|||
|
|||
Note that it's completely api-compatible with the stdlib logger, so you can |
|||
replace your `log` imports everywhere with `log "github.com/Sirupsen/logrus"` |
|||
and you'll now have the flexibility of Logrus. You can customize it all you |
|||
want: |
|||
|
|||
```go |
|||
package main |
|||
|
|||
import ( |
|||
"os" |
|||
log "github.com/Sirupsen/logrus" |
|||
) |
|||
|
|||
func init() { |
|||
// Log as JSON instead of the default ASCII formatter. |
|||
log.SetFormatter(&log.JSONFormatter{}) |
|||
|
|||
// Output to stderr instead of stdout, could also be a file. |
|||
log.SetOutput(os.Stderr) |
|||
|
|||
// Only log the warning severity or above. |
|||
log.SetLevel(log.WarnLevel) |
|||
} |
|||
|
|||
func main() { |
|||
log.WithFields(log.Fields{ |
|||
"animal": "walrus", |
|||
"size": 10, |
|||
}).Info("A group of walrus emerges from the ocean") |
|||
|
|||
log.WithFields(log.Fields{ |
|||
"omg": true, |
|||
"number": 122, |
|||
}).Warn("The group's number increased tremendously!") |
|||
|
|||
log.WithFields(log.Fields{ |
|||
"omg": true, |
|||
"number": 100, |
|||
}).Fatal("The ice breaks!") |
|||
|
|||
// A common pattern is to re-use fields between logging statements by re-using |
|||
// the logrus.Entry returned from WithFields() |
|||
contextLogger := log.WithFields(log.Fields{ |
|||
"common": "this is a common field", |
|||
"other": "I also should be logged always", |
|||
}) |
|||
|
|||
contextLogger.Info("I'll be logged with common and other field") |
|||
contextLogger.Info("Me too") |
|||
} |
|||
``` |
|||
|
|||
For more advanced usage such as logging to multiple locations from the same |
|||
application, you can also create an instance of the `logrus` Logger: |
|||
|
|||
```go |
|||
package main |
|||
|
|||
import ( |
|||
"github.com/Sirupsen/logrus" |
|||
) |
|||
|
|||
// Create a new instance of the logger. You can have any number of instances. |
|||
var log = logrus.New() |
|||
|
|||
func main() { |
|||
// The API for setting attributes is a little different than the package level |
|||
// exported logger. See Godoc. |
|||
log.Out = os.Stderr |
|||
|
|||
log.WithFields(logrus.Fields{ |
|||
"animal": "walrus", |
|||
"size": 10, |
|||
}).Info("A group of walrus emerges from the ocean") |
|||
} |
|||
``` |
|||
|
|||
#### Fields |
|||
|
|||
Logrus encourages careful, structured logging though logging fields instead of |
|||
long, unparseable error messages. For example, instead of: `log.Fatalf("Failed |
|||
to send event %s to topic %s with key %d")`, you should log the much more |
|||
discoverable: |
|||
|
|||
```go |
|||
log.WithFields(log.Fields{ |
|||
"event": event, |
|||
"topic": topic, |
|||
"key": key, |
|||
}).Fatal("Failed to send event") |
|||
``` |
|||
|
|||
We've found this API forces you to think about logging in a way that produces |
|||
much more useful logging messages. We've been in countless situations where just |
|||
a single added field to a log statement that was already there would've saved us |
|||
hours. The `WithFields` call is optional. |
|||
|
|||
In general, with Logrus using any of the `printf`-family functions should be |
|||
seen as a hint you should add a field, however, you can still use the |
|||
`printf`-family functions with Logrus. |
|||
|
|||
#### Hooks |
|||
|
|||
You can add hooks for logging levels. For example to send errors to an exception |
|||
tracking service on `Error`, `Fatal` and `Panic`, info to StatsD or log to |
|||
multiple places simultaneously, e.g. syslog. |
|||
|
|||
Logrus comes with [built-in hooks](hooks/). Add those, or your custom hook, in |
|||
`init`: |
|||
|
|||
```go |
|||
import ( |
|||
log "github.com/Sirupsen/logrus" |
|||
"gopkg.in/gemnasium/logrus-airbrake-hook.v2" // the package is named "aibrake" |
|||
logrus_syslog "github.com/Sirupsen/logrus/hooks/syslog" |
|||
"log/syslog" |
|||
) |
|||
|
|||
func init() { |
|||
|
|||
// Use the Airbrake hook to report errors that have Error severity or above to |
|||
// an exception tracker. You can create custom hooks, see the Hooks section. |
|||
log.AddHook(airbrake.NewHook(123, "xyz", "production")) |
|||
|
|||
hook, err := logrus_syslog.NewSyslogHook("udp", "localhost:514", syslog.LOG_INFO, "") |
|||
if err != nil { |
|||
log.Error("Unable to connect to local syslog daemon") |
|||
} else { |
|||
log.AddHook(hook) |
|||
} |
|||
} |
|||
``` |
|||
Note: Syslog hook also support connecting to local syslog (Ex. "/dev/log" or "/var/run/syslog" or "/var/run/log"). For the detail, please check the [syslog hook README](hooks/syslog/README.md). |
|||
|
|||
| Hook | Description | |
|||
| ----- | ----------- | |
|||
| [Airbrake](https://github.com/gemnasium/logrus-airbrake-hook) | Send errors to the Airbrake API V3. Uses the official [`gobrake`](https://github.com/airbrake/gobrake) behind the scenes. | |
|||
| [Airbrake "legacy"](https://github.com/gemnasium/logrus-airbrake-legacy-hook) | Send errors to an exception tracking service compatible with the Airbrake API V2. Uses [`airbrake-go`](https://github.com/tobi/airbrake-go) behind the scenes. | |
|||
| [Papertrail](https://github.com/polds/logrus-papertrail-hook) | Send errors to the [Papertrail](https://papertrailapp.com) hosted logging service via UDP. | |
|||
| [Syslog](https://github.com/Sirupsen/logrus/blob/master/hooks/syslog/syslog.go) | Send errors to remote syslog server. Uses standard library `log/syslog` behind the scenes. | |
|||
| [Bugsnag](https://github.com/Shopify/logrus-bugsnag/blob/master/bugsnag.go) | Send errors to the Bugsnag exception tracking service. | |
|||
| [Sentry](https://github.com/evalphobia/logrus_sentry) | Send errors to the Sentry error logging and aggregation service. | |
|||
| [Hiprus](https://github.com/nubo/hiprus) | Send errors to a channel in hipchat. | |
|||
| [Logrusly](https://github.com/sebest/logrusly) | Send logs to [Loggly](https://www.loggly.com/) | |
|||
| [Slackrus](https://github.com/johntdyer/slackrus) | Hook for Slack chat. | |
|||
| [Journalhook](https://github.com/wercker/journalhook) | Hook for logging to `systemd-journald` | |
|||
| [Graylog](https://github.com/gemnasium/logrus-graylog-hook) | Hook for logging to [Graylog](http://graylog2.org/) | |
|||
| [Raygun](https://github.com/squirkle/logrus-raygun-hook) | Hook for logging to [Raygun.io](http://raygun.io/) | |
|||
| [LFShook](https://github.com/rifflock/lfshook) | Hook for logging to the local filesystem | |
|||
| [Honeybadger](https://github.com/agonzalezro/logrus_honeybadger) | Hook for sending exceptions to Honeybadger | |
|||
| [Mail](https://github.com/zbindenren/logrus_mail) | Hook for sending exceptions via mail | |
|||
| [Rollrus](https://github.com/heroku/rollrus) | Hook for sending errors to rollbar | |
|||
| [Fluentd](https://github.com/evalphobia/logrus_fluent) | Hook for logging to fluentd | |
|||
| [Mongodb](https://github.com/weekface/mgorus) | Hook for logging to mongodb | |
|||
| [Influxus] (http://github.com/vlad-doru/influxus) | Hook for concurrently logging to [InfluxDB] (http://influxdata.com/) | |
|||
| [InfluxDB](https://github.com/Abramovic/logrus_influxdb) | Hook for logging to influxdb | |
|||
| [Octokit](https://github.com/dorajistyle/logrus-octokit-hook) | Hook for logging to github via octokit | |
|||
| [DeferPanic](https://github.com/deferpanic/dp-logrus) | Hook for logging to DeferPanic | |
|||
| [Redis-Hook](https://github.com/rogierlommers/logrus-redis-hook) | Hook for logging to a ELK stack (through Redis) | |
|||
| [Amqp-Hook](https://github.com/vladoatanasov/logrus_amqp) | Hook for logging to Amqp broker (Like RabbitMQ) | |
|||
| [KafkaLogrus](https://github.com/goibibo/KafkaLogrus) | Hook for logging to kafka | |
|||
| [Typetalk](https://github.com/dragon3/logrus-typetalk-hook) | Hook for logging to [Typetalk](https://www.typetalk.in/) | |
|||
| [ElasticSearch](https://github.com/sohlich/elogrus) | Hook for logging to ElasticSearch| |
|||
| [Sumorus](https://github.com/doublefree/sumorus) | Hook for logging to [SumoLogic](https://www.sumologic.com/)| |
|||
| [Logstash](https://github.com/bshuster-repo/logrus-logstash-hook) | Hook for logging to [Logstash](https://www.elastic.co/products/logstash) | |
|||
| [Logmatic.io](https://github.com/logmatic/logmatic-go) | Hook for logging to [Logmatic.io](http://logmatic.io/) | |
|||
|
|||
|
|||
#### Level logging |
|||
|
|||
Logrus has six logging levels: Debug, Info, Warning, Error, Fatal and Panic. |
|||
|
|||
```go |
|||
log.Debug("Useful debugging information.") |
|||
log.Info("Something noteworthy happened!") |
|||
log.Warn("You should probably take a look at this.") |
|||
log.Error("Something failed but I'm not quitting.") |
|||
// Calls os.Exit(1) after logging |
|||
log.Fatal("Bye.") |
|||
// Calls panic() after logging |
|||
log.Panic("I'm bailing.") |
|||
``` |
|||
|
|||
You can set the logging level on a `Logger`, then it will only log entries with |
|||
that severity or anything above it: |
|||
|
|||
```go |
|||
// Will log anything that is info or above (warn, error, fatal, panic). Default. |
|||
log.SetLevel(log.InfoLevel) |
|||
``` |
|||
|
|||
It may be useful to set `log.Level = logrus.DebugLevel` in a debug or verbose |
|||
environment if your application has that. |
|||
|
|||
#### Entries |
|||
|
|||
Besides the fields added with `WithField` or `WithFields` some fields are |
|||
automatically added to all logging events: |
|||
|
|||
1. `time`. The timestamp when the entry was created. |
|||
2. `msg`. The logging message passed to `{Info,Warn,Error,Fatal,Panic}` after |
|||
the `AddFields` call. E.g. `Failed to send event.` |
|||
3. `level`. The logging level. E.g. `info`. |
|||
|
|||
#### Environments |
|||
|
|||
Logrus has no notion of environment. |
|||
|
|||
If you wish for hooks and formatters to only be used in specific environments, |
|||
you should handle that yourself. For example, if your application has a global |
|||
variable `Environment`, which is a string representation of the environment you |
|||
could do: |
|||
|
|||
```go |
|||
import ( |
|||
log "github.com/Sirupsen/logrus" |
|||
) |
|||
|
|||
init() { |
|||
// do something here to set environment depending on an environment variable |
|||
// or command-line flag |
|||
if Environment == "production" { |
|||
log.SetFormatter(&log.JSONFormatter{}) |
|||
} else { |
|||
// The TextFormatter is default, you don't actually have to do this. |
|||
log.SetFormatter(&log.TextFormatter{}) |
|||
} |
|||
} |
|||
``` |
|||
|
|||
This configuration is how `logrus` was intended to be used, but JSON in |
|||
production is mostly only useful if you do log aggregation with tools like |
|||
Splunk or Logstash. |
|||
|
|||
#### Formatters |
|||
|
|||
The built-in logging formatters are: |
|||
|
|||
* `logrus.TextFormatter`. Logs the event in colors if stdout is a tty, otherwise |
|||
without colors. |
|||
* *Note:* to force colored output when there is no TTY, set the `ForceColors` |
|||
field to `true`. To force no colored output even if there is a TTY set the |
|||
`DisableColors` field to `true` |
|||
* `logrus.JSONFormatter`. Logs fields as JSON. |
|||
|
|||
Third party logging formatters: |
|||
|
|||
* [`logstash`](https://github.com/bshuster-repo/logrus-logstash-hook). Logs fields as [Logstash](http://logstash.net) Events. |
|||
* [`prefixed`](https://github.com/x-cray/logrus-prefixed-formatter). Displays log entry source along with alternative layout. |
|||
* [`zalgo`](https://github.com/aybabtme/logzalgo). Invoking the P͉̫o̳̼̊w̖͈̰͎e̬͔̭͂r͚̼̹̲ ̫͓͉̳͈ō̠͕͖̚f̝͍̠ ͕̲̞͖͑Z̖̫̤̫ͪa͉̬͈̗l͖͎g̳̥o̰̥̅!̣͔̲̻͊̄ ̙̘̦̹̦. |
|||
|
|||
You can define your formatter by implementing the `Formatter` interface, |
|||
requiring a `Format` method. `Format` takes an `*Entry`. `entry.Data` is a |
|||
`Fields` type (`map[string]interface{}`) with all your fields as well as the |
|||
default ones (see Entries section above): |
|||
|
|||
```go |
|||
type MyJSONFormatter struct { |
|||
} |
|||
|
|||
log.SetFormatter(new(MyJSONFormatter)) |
|||
|
|||
func (f *MyJSONFormatter) Format(entry *Entry) ([]byte, error) { |
|||
// Note this doesn't include Time, Level and Message which are available on |
|||
// the Entry. Consult `godoc` on information about those fields or read the |
|||
// source of the official loggers. |
|||
serialized, err := json.Marshal(entry.Data) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err) |
|||
} |
|||
return append(serialized, '\n'), nil |
|||
} |
|||
``` |
|||
|
|||
#### Logger as an `io.Writer` |
|||
|
|||
Logrus can be transformed into an `io.Writer`. That writer is the end of an `io.Pipe` and it is your responsibility to close it. |
|||
|
|||
```go |
|||
w := logger.Writer() |
|||
defer w.Close() |
|||
|
|||
srv := http.Server{ |
|||
// create a stdlib log.Logger that writes to |
|||
// logrus.Logger. |
|||
ErrorLog: log.New(w, "", 0), |
|||
} |
|||
``` |
|||
|
|||
Each line written to that writer will be printed the usual way, using formatters |
|||
and hooks. The level for those entries is `info`. |
|||
|
|||
#### Rotation |
|||
|
|||
Log rotation is not provided with Logrus. Log rotation should be done by an |
|||
external program (like `logrotate(8)`) that can compress and delete old log |
|||
entries. It should not be a feature of the application-level logger. |
|||
|
|||
#### Tools |
|||
|
|||
| Tool | Description | |
|||
| ---- | ----------- | |
|||
|[Logrus Mate](https://github.com/gogap/logrus_mate)|Logrus mate is a tool for Logrus to manage loggers, you can initial logger's level, hook and formatter by config file, the logger will generated with different config at different environment.| |
|||
|
|||
#### Testing |
|||
|
|||
Logrus has a built in facility for asserting the presence of log messages. This is implemented through the `test` hook and provides: |
|||
|
|||
* decorators for existing logger (`test.NewLocal` and `test.NewGlobal`) which basically just add the `test` hook |
|||
* a test logger (`test.NewNullLogger`) that just records log messages (and does not output any): |
|||
|
|||
```go |
|||
logger, hook := NewNullLogger() |
|||
logger.Error("Hello error") |
|||
|
|||
assert.Equal(1, len(hook.Entries)) |
|||
assert.Equal(logrus.ErrorLevel, hook.LastEntry().Level) |
|||
assert.Equal("Hello error", hook.LastEntry().Message) |
|||
|
|||
hook.Reset() |
|||
assert.Nil(hook.LastEntry()) |
|||
``` |
|||
|
|||
#### Fatal handlers |
|||
|
|||
Logrus can register one or more functions that will be called when any `fatal` |
|||
level message is logged. The registered handlers will be executed before |
|||
logrus performs a `os.Exit(1)`. This behavior may be helpful if callers need |
|||
to gracefully shutdown. Unlike a `panic("Something went wrong...")` call which can be intercepted with a deferred `recover` a call to `os.Exit(1)` can not be intercepted. |
|||
|
|||
``` |
|||
... |
|||
handler := func() { |
|||
// gracefully shutdown something... |
|||
} |
|||
logrus.RegisterExitHandler(handler) |
|||
... |
|||
``` |
|||
|
|||
#### Thread safty |
|||
|
|||
By default Logger is protected by mutex for concurrent writes, this mutex is invoked when calling hooks and writing logs. |
|||
If you are sure such locking is not needed, you can call logger.SetNoLock() to disable the locking. |
|||
|
|||
Situation when locking is not needed includes: |
|||
|
|||
* You have no hooks registered, or hooks calling is already thread-safe. |
|||
|
|||
* Writing to logger.Out is already thread-safe, for example: |
|||
|
|||
1) logger.Out is protected by locks. |
|||
|
|||
2) logger.Out is a os.File handler opened with `O_APPEND` flag, and every write is smaller than 4k. (This allow multi-thread/multi-process writing) |
|||
|
|||
(Refer to http://www.notthewizard.com/2014/06/17/are-files-appends-really-atomic/) |
@ -0,0 +1,64 @@ |
|||
package logrus |
|||
|
|||
// The following code was sourced and modified from the
|
|||
// https://bitbucket.org/tebeka/atexit package governed by the following license:
|
|||
//
|
|||
// Copyright (c) 2012 Miki Tebeka <miki.tebeka@gmail.com>.
|
|||
//
|
|||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|||
// this software and associated documentation files (the "Software"), to deal in
|
|||
// the Software without restriction, including without limitation the rights to
|
|||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
|||
// subject to the following conditions:
|
|||
//
|
|||
// The above copyright notice and this permission notice shall be included in all
|
|||
// copies or substantial portions of the Software.
|
|||
//
|
|||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|||
|
|||
import ( |
|||
"fmt" |
|||
"os" |
|||
) |
|||
|
|||
var handlers = []func(){} |
|||
|
|||
func runHandler(handler func()) { |
|||
defer func() { |
|||
if err := recover(); err != nil { |
|||
fmt.Fprintln(os.Stderr, "Error: Logrus exit handler error:", err) |
|||
} |
|||
}() |
|||
|
|||
handler() |
|||
} |
|||
|
|||
func runHandlers() { |
|||
for _, handler := range handlers { |
|||
runHandler(handler) |
|||
} |
|||
} |
|||
|
|||
// Exit runs all the Logrus atexit handlers and then terminates the program using os.Exit(code)
|
|||
func Exit(code int) { |
|||
runHandlers() |
|||
os.Exit(code) |
|||
} |
|||
|
|||
// RegisterExitHandler adds a Logrus Exit handler, call logrus.Exit to invoke
|
|||
// all handlers. The handlers will also be invoked when any Fatal log entry is
|
|||
// made.
|
|||
//
|
|||
// This method is useful when a caller wishes to use logrus to log a fatal
|
|||
// message but also needs to gracefully shutdown. An example usecase could be
|
|||
// closing database connections, or sending a alert that the application is
|
|||
// closing.
|
|||
func RegisterExitHandler(handler func()) { |
|||
handlers = append(handlers, handler) |
|||
} |
@ -0,0 +1,26 @@ |
|||
/* |
|||
Package logrus is a structured logger for Go, completely API compatible with the standard library logger. |
|||
|
|||
|
|||
The simplest way to use Logrus is simply the package-level exported logger: |
|||
|
|||
package main |
|||
|
|||
import ( |
|||
log "github.com/Sirupsen/logrus" |
|||
) |
|||
|
|||
func main() { |
|||
log.WithFields(log.Fields{ |
|||
"animal": "walrus", |
|||
"number": 1, |
|||
"size": 10, |
|||
}).Info("A walrus appears") |
|||
} |
|||
|
|||
Output: |
|||
time="2015-09-07T08:48:33Z" level=info msg="A walrus appears" animal=walrus number=1 size=10 |
|||
|
|||
For a full guide visit https://github.com/Sirupsen/logrus
|
|||
*/ |
|||
package logrus |
@ -0,0 +1,275 @@ |
|||
package logrus |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"os" |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
var bufferPool *sync.Pool |
|||
|
|||
func init() { |
|||
bufferPool = &sync.Pool{ |
|||
New: func() interface{} { |
|||
return new(bytes.Buffer) |
|||
}, |
|||
} |
|||
} |
|||
|
|||
// Defines the key when adding errors using WithError.
|
|||
var ErrorKey = "error" |
|||
|
|||
// An entry is the final or intermediate Logrus logging entry. It contains all
|
|||
// the fields passed with WithField{,s}. It's finally logged when Debug, Info,
|
|||
// Warn, Error, Fatal or Panic is called on it. These objects can be reused and
|
|||
// passed around as much as you wish to avoid field duplication.
|
|||
type Entry struct { |
|||
Logger *Logger |
|||
|
|||
// Contains all the fields set by the user.
|
|||
Data Fields |
|||
|
|||
// Time at which the log entry was created
|
|||
Time time.Time |
|||
|
|||
// Level the log entry was logged at: Debug, Info, Warn, Error, Fatal or Panic
|
|||
Level Level |
|||
|
|||
// Message passed to Debug, Info, Warn, Error, Fatal or Panic
|
|||
Message string |
|||
|
|||
// When formatter is called in entry.log(), an Buffer may be set to entry
|
|||
Buffer *bytes.Buffer |
|||
} |
|||
|
|||
func NewEntry(logger *Logger) *Entry { |
|||
return &Entry{ |
|||
Logger: logger, |
|||
// Default is three fields, give a little extra room
|
|||
Data: make(Fields, 5), |
|||
} |
|||
} |
|||
|
|||
// Returns the string representation from the reader and ultimately the
|
|||
// formatter.
|
|||
func (entry *Entry) String() (string, error) { |
|||
serialized, err := entry.Logger.Formatter.Format(entry) |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
str := string(serialized) |
|||
return str, nil |
|||
} |
|||
|
|||
// Add an error as single field (using the key defined in ErrorKey) to the Entry.
|
|||
func (entry *Entry) WithError(err error) *Entry { |
|||
return entry.WithField(ErrorKey, err) |
|||
} |
|||
|
|||
// Add a single field to the Entry.
|
|||
func (entry *Entry) WithField(key string, value interface{}) *Entry { |
|||
return entry.WithFields(Fields{key: value}) |
|||
} |
|||
|
|||
// Add a map of fields to the Entry.
|
|||
func (entry *Entry) WithFields(fields Fields) *Entry { |
|||
data := make(Fields, len(entry.Data)+len(fields)) |
|||
for k, v := range entry.Data { |
|||
data[k] = v |
|||
} |
|||
for k, v := range fields { |
|||
data[k] = v |
|||
} |
|||
return &Entry{Logger: entry.Logger, Data: data} |
|||
} |
|||
|
|||
// This function is not declared with a pointer value because otherwise
|
|||
// race conditions will occur when using multiple goroutines
|
|||
func (entry Entry) log(level Level, msg string) { |
|||
var buffer *bytes.Buffer |
|||
entry.Time = time.Now() |
|||
entry.Level = level |
|||
entry.Message = msg |
|||
|
|||
if err := entry.Logger.Hooks.Fire(level, &entry); err != nil { |
|||
entry.Logger.mu.Lock() |
|||
fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err) |
|||
entry.Logger.mu.Unlock() |
|||
} |
|||
buffer = bufferPool.Get().(*bytes.Buffer) |
|||
buffer.Reset() |
|||
defer bufferPool.Put(buffer) |
|||
entry.Buffer = buffer |
|||
serialized, err := entry.Logger.Formatter.Format(&entry) |
|||
entry.Buffer = nil |
|||
if err != nil { |
|||
entry.Logger.mu.Lock() |
|||
fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err) |
|||
entry.Logger.mu.Unlock() |
|||
} else { |
|||
entry.Logger.mu.Lock() |
|||
_, err = entry.Logger.Out.Write(serialized) |
|||
if err != nil { |
|||
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err) |
|||
} |
|||
entry.Logger.mu.Unlock() |
|||
} |
|||
|
|||
// To avoid Entry#log() returning a value that only would make sense for
|
|||
// panic() to use in Entry#Panic(), we avoid the allocation by checking
|
|||
// directly here.
|
|||
if level <= PanicLevel { |
|||
panic(&entry) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Debug(args ...interface{}) { |
|||
if entry.Logger.Level >= DebugLevel { |
|||
entry.log(DebugLevel, fmt.Sprint(args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Print(args ...interface{}) { |
|||
entry.Info(args...) |
|||
} |
|||
|
|||
func (entry *Entry) Info(args ...interface{}) { |
|||
if entry.Logger.Level >= InfoLevel { |
|||
entry.log(InfoLevel, fmt.Sprint(args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Warn(args ...interface{}) { |
|||
if entry.Logger.Level >= WarnLevel { |
|||
entry.log(WarnLevel, fmt.Sprint(args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Warning(args ...interface{}) { |
|||
entry.Warn(args...) |
|||
} |
|||
|
|||
func (entry *Entry) Error(args ...interface{}) { |
|||
if entry.Logger.Level >= ErrorLevel { |
|||
entry.log(ErrorLevel, fmt.Sprint(args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Fatal(args ...interface{}) { |
|||
if entry.Logger.Level >= FatalLevel { |
|||
entry.log(FatalLevel, fmt.Sprint(args...)) |
|||
} |
|||
Exit(1) |
|||
} |
|||
|
|||
func (entry *Entry) Panic(args ...interface{}) { |
|||
if entry.Logger.Level >= PanicLevel { |
|||
entry.log(PanicLevel, fmt.Sprint(args...)) |
|||
} |
|||
panic(fmt.Sprint(args...)) |
|||
} |
|||
|
|||
// Entry Printf family functions
|
|||
|
|||
func (entry *Entry) Debugf(format string, args ...interface{}) { |
|||
if entry.Logger.Level >= DebugLevel { |
|||
entry.Debug(fmt.Sprintf(format, args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Infof(format string, args ...interface{}) { |
|||
if entry.Logger.Level >= InfoLevel { |
|||
entry.Info(fmt.Sprintf(format, args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Printf(format string, args ...interface{}) { |
|||
entry.Infof(format, args...) |
|||
} |
|||
|
|||
func (entry *Entry) Warnf(format string, args ...interface{}) { |
|||
if entry.Logger.Level >= WarnLevel { |
|||
entry.Warn(fmt.Sprintf(format, args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Warningf(format string, args ...interface{}) { |
|||
entry.Warnf(format, args...) |
|||
} |
|||
|
|||
func (entry *Entry) Errorf(format string, args ...interface{}) { |
|||
if entry.Logger.Level >= ErrorLevel { |
|||
entry.Error(fmt.Sprintf(format, args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Fatalf(format string, args ...interface{}) { |
|||
if entry.Logger.Level >= FatalLevel { |
|||
entry.Fatal(fmt.Sprintf(format, args...)) |
|||
} |
|||
Exit(1) |
|||
} |
|||
|
|||
func (entry *Entry) Panicf(format string, args ...interface{}) { |
|||
if entry.Logger.Level >= PanicLevel { |
|||
entry.Panic(fmt.Sprintf(format, args...)) |
|||
} |
|||
} |
|||
|
|||
// Entry Println family functions
|
|||
|
|||
func (entry *Entry) Debugln(args ...interface{}) { |
|||
if entry.Logger.Level >= DebugLevel { |
|||
entry.Debug(entry.sprintlnn(args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Infoln(args ...interface{}) { |
|||
if entry.Logger.Level >= InfoLevel { |
|||
entry.Info(entry.sprintlnn(args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Println(args ...interface{}) { |
|||
entry.Infoln(args...) |
|||
} |
|||
|
|||
func (entry *Entry) Warnln(args ...interface{}) { |
|||
if entry.Logger.Level >= WarnLevel { |
|||
entry.Warn(entry.sprintlnn(args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Warningln(args ...interface{}) { |
|||
entry.Warnln(args...) |
|||
} |
|||
|
|||
func (entry *Entry) Errorln(args ...interface{}) { |
|||
if entry.Logger.Level >= ErrorLevel { |
|||
entry.Error(entry.sprintlnn(args...)) |
|||
} |
|||
} |
|||
|
|||
func (entry *Entry) Fatalln(args ...interface{}) { |
|||
if entry.Logger.Level >= FatalLevel { |
|||
entry.Fatal(entry.sprintlnn(args...)) |
|||
} |
|||
Exit(1) |
|||
} |
|||
|
|||
func (entry *Entry) Panicln(args ...interface{}) { |
|||
if entry.Logger.Level >= PanicLevel { |
|||
entry.Panic(entry.sprintlnn(args...)) |
|||
} |
|||
} |
|||
|
|||
// Sprintlnn => Sprint no newline. This is to get the behavior of how
|
|||
// fmt.Sprintln where spaces are always added between operands, regardless of
|
|||
// their type. Instead of vendoring the Sprintln implementation to spare a
|
|||
// string allocation, we do the simplest thing.
|
|||
func (entry *Entry) sprintlnn(args ...interface{}) string { |
|||
msg := fmt.Sprintln(args...) |
|||
return msg[:len(msg)-1] |
|||
} |
@ -0,0 +1,193 @@ |
|||
package logrus |
|||
|
|||
import ( |
|||
"io" |
|||
) |
|||
|
|||
var ( |
|||
// std is the name of the standard logger in stdlib `log`
|
|||
std = New() |
|||
) |
|||
|
|||
func StandardLogger() *Logger { |
|||
return std |
|||
} |
|||
|
|||
// SetOutput sets the standard logger output.
|
|||
func SetOutput(out io.Writer) { |
|||
std.mu.Lock() |
|||
defer std.mu.Unlock() |
|||
std.Out = out |
|||
} |
|||
|
|||
// SetFormatter sets the standard logger formatter.
|
|||
func SetFormatter(formatter Formatter) { |
|||
std.mu.Lock() |
|||
defer std.mu.Unlock() |
|||
std.Formatter = formatter |
|||
} |
|||
|
|||
// SetLevel sets the standard logger level.
|
|||
func SetLevel(level Level) { |
|||
std.mu.Lock() |
|||
defer std.mu.Unlock() |
|||
std.Level = level |
|||
} |
|||
|
|||
// GetLevel returns the standard logger level.
|
|||
func GetLevel() Level { |
|||
std.mu.Lock() |
|||
defer std.mu.Unlock() |
|||
return std.Level |
|||
} |
|||
|
|||
// AddHook adds a hook to the standard logger hooks.
|
|||
func AddHook(hook Hook) { |
|||
std.mu.Lock() |
|||
defer std.mu.Unlock() |
|||
std.Hooks.Add(hook) |
|||
} |
|||
|
|||
// WithError creates an entry from the standard logger and adds an error to it, using the value defined in ErrorKey as key.
|
|||
func WithError(err error) *Entry { |
|||
return std.WithField(ErrorKey, err) |
|||
} |
|||
|
|||
// WithField creates an entry from the standard logger and adds a field to
|
|||
// it. If you want multiple fields, use `WithFields`.
|
|||
//
|
|||
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
|
|||
// or Panic on the Entry it returns.
|
|||
func WithField(key string, value interface{}) *Entry { |
|||
return std.WithField(key, value) |
|||
} |
|||
|
|||
// WithFields creates an entry from the standard logger and adds multiple
|
|||
// fields to it. This is simply a helper for `WithField`, invoking it
|
|||
// once for each field.
|
|||
//
|
|||
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
|
|||
// or Panic on the Entry it returns.
|
|||
func WithFields(fields Fields) *Entry { |
|||
return std.WithFields(fields) |
|||
} |
|||
|
|||
// Debug logs a message at level Debug on the standard logger.
|
|||
func Debug(args ...interface{}) { |
|||
std.Debug(args...) |
|||
} |
|||
|
|||
// Print logs a message at level Info on the standard logger.
|
|||
func Print(args ...interface{}) { |
|||
std.Print(args...) |
|||
} |
|||
|
|||
// Info logs a message at level Info on the standard logger.
|
|||
func Info(args ...interface{}) { |
|||
std.Info(args...) |
|||
} |
|||
|
|||
// Warn logs a message at level Warn on the standard logger.
|
|||
func Warn(args ...interface{}) { |
|||
std.Warn(args...) |
|||
} |
|||
|
|||
// Warning logs a message at level Warn on the standard logger.
|
|||
func Warning(args ...interface{}) { |
|||
std.Warning(args...) |
|||
} |
|||
|
|||
// Error logs a message at level Error on the standard logger.
|
|||
func Error(args ...interface{}) { |
|||
std.Error(args...) |
|||
} |
|||
|
|||
// Panic logs a message at level Panic on the standard logger.
|
|||
func Panic(args ...interface{}) { |
|||
std.Panic(args...) |
|||
} |
|||
|
|||
// Fatal logs a message at level Fatal on the standard logger.
|
|||
func Fatal(args ...interface{}) { |
|||
std.Fatal(args...) |
|||
} |
|||
|
|||
// Debugf logs a message at level Debug on the standard logger.
|
|||
func Debugf(format string, args ...interface{}) { |
|||
std.Debugf(format, args...) |
|||
} |
|||
|
|||
// Printf logs a message at level Info on the standard logger.
|
|||
func Printf(format string, args ...interface{}) { |
|||
std.Printf(format, args...) |
|||
} |
|||
|
|||
// Infof logs a message at level Info on the standard logger.
|
|||
func Infof(format string, args ...interface{}) { |
|||
std.Infof(format, args...) |
|||
} |
|||
|
|||
// Warnf logs a message at level Warn on the standard logger.
|
|||
func Warnf(format string, args ...interface{}) { |
|||
std.Warnf(format, args...) |
|||
} |
|||
|
|||
// Warningf logs a message at level Warn on the standard logger.
|
|||
func Warningf(format string, args ...interface{}) { |
|||
std.Warningf(format, args...) |
|||
} |
|||
|
|||
// Errorf logs a message at level Error on the standard logger.
|
|||
func Errorf(format string, args ...interface{}) { |
|||
std.Errorf(format, args...) |
|||
} |
|||
|
|||
// Panicf logs a message at level Panic on the standard logger.
|
|||
func Panicf(format string, args ...interface{}) { |
|||
std.Panicf(format, args...) |
|||
} |
|||
|
|||
// Fatalf logs a message at level Fatal on the standard logger.
|
|||
func Fatalf(format string, args ...interface{}) { |
|||
std.Fatalf(format, args...) |
|||
} |
|||
|
|||
// Debugln logs a message at level Debug on the standard logger.
|
|||
func Debugln(args ...interface{}) { |
|||
std.Debugln(args...) |
|||
} |
|||
|
|||
// Println logs a message at level Info on the standard logger.
|
|||
func Println(args ...interface{}) { |
|||
std.Println(args...) |
|||
} |
|||
|
|||
// Infoln logs a message at level Info on the standard logger.
|
|||
func Infoln(args ...interface{}) { |
|||
std.Infoln(args...) |
|||
} |
|||
|
|||
// Warnln logs a message at level Warn on the standard logger.
|
|||
func Warnln(args ...interface{}) { |
|||
std.Warnln(args...) |
|||
} |
|||
|
|||
// Warningln logs a message at level Warn on the standard logger.
|
|||
func Warningln(args ...interface{}) { |
|||
std.Warningln(args...) |
|||
} |
|||
|
|||
// Errorln logs a message at level Error on the standard logger.
|
|||
func Errorln(args ...interface{}) { |
|||
std.Errorln(args...) |
|||
} |
|||
|
|||
// Panicln logs a message at level Panic on the standard logger.
|
|||
func Panicln(args ...interface{}) { |
|||
std.Panicln(args...) |
|||
} |
|||
|
|||
// Fatalln logs a message at level Fatal on the standard logger.
|
|||
func Fatalln(args ...interface{}) { |
|||
std.Fatalln(args...) |
|||
} |
@ -0,0 +1,45 @@ |
|||
package logrus |
|||
|
|||
import "time" |
|||
|
|||
const DefaultTimestampFormat = time.RFC3339 |
|||
|
|||
// The Formatter interface is used to implement a custom Formatter. It takes an
|
|||
// `Entry`. It exposes all the fields, including the default ones:
|
|||
//
|
|||
// * `entry.Data["msg"]`. The message passed from Info, Warn, Error ..
|
|||
// * `entry.Data["time"]`. The timestamp.
|
|||
// * `entry.Data["level"]. The level the entry was logged at.
|
|||
//
|
|||
// Any additional fields added with `WithField` or `WithFields` are also in
|
|||
// `entry.Data`. Format is expected to return an array of bytes which are then
|
|||
// logged to `logger.Out`.
|
|||
type Formatter interface { |
|||
Format(*Entry) ([]byte, error) |
|||
} |
|||
|
|||
// This is to not silently overwrite `time`, `msg` and `level` fields when
|
|||
// dumping it. If this code wasn't there doing:
|
|||
//
|
|||
// logrus.WithField("level", 1).Info("hello")
|
|||
//
|
|||
// Would just silently drop the user provided level. Instead with this code
|
|||
// it'll logged as:
|
|||
//
|
|||
// {"level": "info", "fields.level": 1, "msg": "hello", "time": "..."}
|
|||
//
|
|||
// It's not exported because it's still using Data in an opinionated way. It's to
|
|||
// avoid code duplication between the two default formatters.
|
|||
func prefixFieldClashes(data Fields) { |
|||
if t, ok := data["time"]; ok { |
|||
data["fields.time"] = t |
|||
} |
|||
|
|||
if m, ok := data["msg"]; ok { |
|||
data["fields.msg"] = m |
|||
} |
|||
|
|||
if l, ok := data["level"]; ok { |
|||
data["fields.level"] = l |
|||
} |
|||
} |
@ -0,0 +1,34 @@ |
|||
package logrus |
|||
|
|||
// A hook to be fired when logging on the logging levels returned from
|
|||
// `Levels()` on your implementation of the interface. Note that this is not
|
|||
// fired in a goroutine or a channel with workers, you should handle such
|
|||
// functionality yourself if your call is non-blocking and you don't wish for
|
|||
// the logging calls for levels returned from `Levels()` to block.
|
|||
type Hook interface { |
|||
Levels() []Level |
|||
Fire(*Entry) error |
|||
} |
|||
|
|||
// Internal type for storing the hooks on a logger instance.
|
|||
type LevelHooks map[Level][]Hook |
|||
|
|||
// Add a hook to an instance of logger. This is called with
|
|||
// `log.Hooks.Add(new(MyHook))` where `MyHook` implements the `Hook` interface.
|
|||
func (hooks LevelHooks) Add(hook Hook) { |
|||
for _, level := range hook.Levels() { |
|||
hooks[level] = append(hooks[level], hook) |
|||
} |
|||
} |
|||
|
|||
// Fire all the hooks for the passed level. Used by `entry.log` to fire
|
|||
// appropriate hooks for a log entry.
|
|||
func (hooks LevelHooks) Fire(level Level, entry *Entry) error { |
|||
for _, hook := range hooks[level] { |
|||
if err := hook.Fire(entry); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
@ -0,0 +1,41 @@ |
|||
package logrus |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"fmt" |
|||
) |
|||
|
|||
type JSONFormatter struct { |
|||
// TimestampFormat sets the format used for marshaling timestamps.
|
|||
TimestampFormat string |
|||
} |
|||
|
|||
func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) { |
|||
data := make(Fields, len(entry.Data)+3) |
|||
for k, v := range entry.Data { |
|||
switch v := v.(type) { |
|||
case error: |
|||
// Otherwise errors are ignored by `encoding/json`
|
|||
// https://github.com/Sirupsen/logrus/issues/137
|
|||
data[k] = v.Error() |
|||
default: |
|||
data[k] = v |
|||
} |
|||
} |
|||
prefixFieldClashes(data) |
|||
|
|||
timestampFormat := f.TimestampFormat |
|||
if timestampFormat == "" { |
|||
timestampFormat = DefaultTimestampFormat |
|||
} |
|||
|
|||
data["time"] = entry.Time.Format(timestampFormat) |
|||
data["msg"] = entry.Message |
|||
data["level"] = entry.Level.String() |
|||
|
|||
serialized, err := json.Marshal(data) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err) |
|||
} |
|||
return append(serialized, '\n'), nil |
|||
} |
@ -0,0 +1,308 @@ |
|||
package logrus |
|||
|
|||
import ( |
|||
"io" |
|||
"os" |
|||
"sync" |
|||
) |
|||
|
|||
type Logger struct { |
|||
// The logs are `io.Copy`'d to this in a mutex. It's common to set this to a
|
|||
// file, or leave it default which is `os.Stderr`. You can also set this to
|
|||
// something more adventorous, such as logging to Kafka.
|
|||
Out io.Writer |
|||
// Hooks for the logger instance. These allow firing events based on logging
|
|||
// levels and log entries. For example, to send errors to an error tracking
|
|||
// service, log to StatsD or dump the core on fatal errors.
|
|||
Hooks LevelHooks |
|||
// All log entries pass through the formatter before logged to Out. The
|
|||
// included formatters are `TextFormatter` and `JSONFormatter` for which
|
|||
// TextFormatter is the default. In development (when a TTY is attached) it
|
|||
// logs with colors, but to a file it wouldn't. You can easily implement your
|
|||
// own that implements the `Formatter` interface, see the `README` or included
|
|||
// formatters for examples.
|
|||
Formatter Formatter |
|||
// The logging level the logger should log at. This is typically (and defaults
|
|||
// to) `logrus.Info`, which allows Info(), Warn(), Error() and Fatal() to be
|
|||
// logged. `logrus.Debug` is useful in
|
|||
Level Level |
|||
// Used to sync writing to the log. Locking is enabled by Default
|
|||
mu MutexWrap |
|||
// Reusable empty entry
|
|||
entryPool sync.Pool |
|||
} |
|||
|
|||
type MutexWrap struct { |
|||
lock sync.Mutex |
|||
disabled bool |
|||
} |
|||
|
|||
func (mw *MutexWrap) Lock() { |
|||
if !mw.disabled { |
|||
mw.lock.Lock() |
|||
} |
|||
} |
|||
|
|||
func (mw *MutexWrap) Unlock() { |
|||
if !mw.disabled { |
|||
mw.lock.Unlock() |
|||
} |
|||
} |
|||
|
|||
func (mw *MutexWrap) Disable() { |
|||
mw.disabled = true |
|||
} |
|||
|
|||
// Creates a new logger. Configuration should be set by changing `Formatter`,
|
|||
// `Out` and `Hooks` directly on the default logger instance. You can also just
|
|||
// instantiate your own:
|
|||
//
|
|||
// var log = &Logger{
|
|||
// Out: os.Stderr,
|
|||
// Formatter: new(JSONFormatter),
|
|||
// Hooks: make(LevelHooks),
|
|||
// Level: logrus.DebugLevel,
|
|||
// }
|
|||
//
|
|||
// It's recommended to make this a global instance called `log`.
|
|||
func New() *Logger { |
|||
return &Logger{ |
|||
Out: os.Stderr, |
|||
Formatter: new(TextFormatter), |
|||
Hooks: make(LevelHooks), |
|||
Level: InfoLevel, |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) newEntry() *Entry { |
|||
entry, ok := logger.entryPool.Get().(*Entry) |
|||
if ok { |
|||
return entry |
|||
} |
|||
return NewEntry(logger) |
|||
} |
|||
|
|||
func (logger *Logger) releaseEntry(entry *Entry) { |
|||
logger.entryPool.Put(entry) |
|||
} |
|||
|
|||
// Adds a field to the log entry, note that it doesn't log until you call
|
|||
// Debug, Print, Info, Warn, Fatal or Panic. It only creates a log entry.
|
|||
// If you want multiple fields, use `WithFields`.
|
|||
func (logger *Logger) WithField(key string, value interface{}) *Entry { |
|||
entry := logger.newEntry() |
|||
defer logger.releaseEntry(entry) |
|||
return entry.WithField(key, value) |
|||
} |
|||
|
|||
// Adds a struct of fields to the log entry. All it does is call `WithField` for
|
|||
// each `Field`.
|
|||
func (logger *Logger) WithFields(fields Fields) *Entry { |
|||
entry := logger.newEntry() |
|||
defer logger.releaseEntry(entry) |
|||
return entry.WithFields(fields) |
|||
} |
|||
|
|||
// Add an error as single field to the log entry. All it does is call
|
|||
// `WithError` for the given `error`.
|
|||
func (logger *Logger) WithError(err error) *Entry { |
|||
entry := logger.newEntry() |
|||
defer logger.releaseEntry(entry) |
|||
return entry.WithError(err) |
|||
} |
|||
|
|||
func (logger *Logger) Debugf(format string, args ...interface{}) { |
|||
if logger.Level >= DebugLevel { |
|||
entry := logger.newEntry() |
|||
entry.Debugf(format, args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Infof(format string, args ...interface{}) { |
|||
if logger.Level >= InfoLevel { |
|||
entry := logger.newEntry() |
|||
entry.Infof(format, args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Printf(format string, args ...interface{}) { |
|||
entry := logger.newEntry() |
|||
entry.Printf(format, args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
|
|||
func (logger *Logger) Warnf(format string, args ...interface{}) { |
|||
if logger.Level >= WarnLevel { |
|||
entry := logger.newEntry() |
|||
entry.Warnf(format, args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Warningf(format string, args ...interface{}) { |
|||
if logger.Level >= WarnLevel { |
|||
entry := logger.newEntry() |
|||
entry.Warnf(format, args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Errorf(format string, args ...interface{}) { |
|||
if logger.Level >= ErrorLevel { |
|||
entry := logger.newEntry() |
|||
entry.Errorf(format, args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Fatalf(format string, args ...interface{}) { |
|||
if logger.Level >= FatalLevel { |
|||
entry := logger.newEntry() |
|||
entry.Fatalf(format, args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
Exit(1) |
|||
} |
|||
|
|||
func (logger *Logger) Panicf(format string, args ...interface{}) { |
|||
if logger.Level >= PanicLevel { |
|||
entry := logger.newEntry() |
|||
entry.Panicf(format, args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Debug(args ...interface{}) { |
|||
if logger.Level >= DebugLevel { |
|||
entry := logger.newEntry() |
|||
entry.Debug(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Info(args ...interface{}) { |
|||
if logger.Level >= InfoLevel { |
|||
entry := logger.newEntry() |
|||
entry.Info(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Print(args ...interface{}) { |
|||
entry := logger.newEntry() |
|||
entry.Info(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
|
|||
func (logger *Logger) Warn(args ...interface{}) { |
|||
if logger.Level >= WarnLevel { |
|||
entry := logger.newEntry() |
|||
entry.Warn(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Warning(args ...interface{}) { |
|||
if logger.Level >= WarnLevel { |
|||
entry := logger.newEntry() |
|||
entry.Warn(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Error(args ...interface{}) { |
|||
if logger.Level >= ErrorLevel { |
|||
entry := logger.newEntry() |
|||
entry.Error(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Fatal(args ...interface{}) { |
|||
if logger.Level >= FatalLevel { |
|||
entry := logger.newEntry() |
|||
entry.Fatal(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
Exit(1) |
|||
} |
|||
|
|||
func (logger *Logger) Panic(args ...interface{}) { |
|||
if logger.Level >= PanicLevel { |
|||
entry := logger.newEntry() |
|||
entry.Panic(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Debugln(args ...interface{}) { |
|||
if logger.Level >= DebugLevel { |
|||
entry := logger.newEntry() |
|||
entry.Debugln(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Infoln(args ...interface{}) { |
|||
if logger.Level >= InfoLevel { |
|||
entry := logger.newEntry() |
|||
entry.Infoln(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Println(args ...interface{}) { |
|||
entry := logger.newEntry() |
|||
entry.Println(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
|
|||
func (logger *Logger) Warnln(args ...interface{}) { |
|||
if logger.Level >= WarnLevel { |
|||
entry := logger.newEntry() |
|||
entry.Warnln(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Warningln(args ...interface{}) { |
|||
if logger.Level >= WarnLevel { |
|||
entry := logger.newEntry() |
|||
entry.Warnln(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Errorln(args ...interface{}) { |
|||
if logger.Level >= ErrorLevel { |
|||
entry := logger.newEntry() |
|||
entry.Errorln(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
func (logger *Logger) Fatalln(args ...interface{}) { |
|||
if logger.Level >= FatalLevel { |
|||
entry := logger.newEntry() |
|||
entry.Fatalln(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
Exit(1) |
|||
} |
|||
|
|||
func (logger *Logger) Panicln(args ...interface{}) { |
|||
if logger.Level >= PanicLevel { |
|||
entry := logger.newEntry() |
|||
entry.Panicln(args...) |
|||
logger.releaseEntry(entry) |
|||
} |
|||
} |
|||
|
|||
//When file is opened with appending mode, it's safe to
|
|||
//write concurrently to a file (within 4k message on Linux).
|
|||
//In these cases user can choose to disable the lock.
|
|||
func (logger *Logger) SetNoLock() { |
|||
logger.mu.Disable() |
|||
} |
@ -0,0 +1,143 @@ |
|||
package logrus |
|||
|
|||
import ( |
|||
"fmt" |
|||
"log" |
|||
"strings" |
|||
) |
|||
|
|||
// Fields type, used to pass to `WithFields`.
|
|||
type Fields map[string]interface{} |
|||
|
|||
// Level type
|
|||
type Level uint8 |
|||
|
|||
// Convert the Level to a string. E.g. PanicLevel becomes "panic".
|
|||
func (level Level) String() string { |
|||
switch level { |
|||
case DebugLevel: |
|||
return "debug" |
|||
case InfoLevel: |
|||
return "info" |
|||
case WarnLevel: |
|||
return "warning" |
|||
case ErrorLevel: |
|||
return "error" |
|||
case FatalLevel: |
|||
return "fatal" |
|||
case PanicLevel: |
|||
return "panic" |
|||
} |
|||
|
|||
return "unknown" |
|||
} |
|||
|
|||
// ParseLevel takes a string level and returns the Logrus log level constant.
|
|||
func ParseLevel(lvl string) (Level, error) { |
|||
switch strings.ToLower(lvl) { |
|||
case "panic": |
|||
return PanicLevel, nil |
|||
case "fatal": |
|||
return FatalLevel, nil |
|||
case "error": |
|||
return ErrorLevel, nil |
|||
case "warn", "warning": |
|||
return WarnLevel, nil |
|||
case "info": |
|||
return InfoLevel, nil |
|||
case "debug": |
|||
return DebugLevel, nil |
|||
} |
|||
|
|||
var l Level |
|||
return l, fmt.Errorf("not a valid logrus Level: %q", lvl) |
|||
} |
|||
|
|||
// A constant exposing all logging levels
|
|||
var AllLevels = []Level{ |
|||
PanicLevel, |
|||
FatalLevel, |
|||
ErrorLevel, |
|||
WarnLevel, |
|||
InfoLevel, |
|||
DebugLevel, |
|||
} |
|||
|
|||
// These are the different logging levels. You can set the logging level to log
|
|||
// on your instance of logger, obtained with `logrus.New()`.
|
|||
const ( |
|||
// PanicLevel level, highest level of severity. Logs and then calls panic with the
|
|||
// message passed to Debug, Info, ...
|
|||
PanicLevel Level = iota |
|||
// FatalLevel level. Logs and then calls `os.Exit(1)`. It will exit even if the
|
|||
// logging level is set to Panic.
|
|||
FatalLevel |
|||
// ErrorLevel level. Logs. Used for errors that should definitely be noted.
|
|||
// Commonly used for hooks to send errors to an error tracking service.
|
|||
ErrorLevel |
|||
// WarnLevel level. Non-critical entries that deserve eyes.
|
|||
WarnLevel |
|||
// InfoLevel level. General operational entries about what's going on inside the
|
|||
// application.
|
|||
InfoLevel |
|||
// DebugLevel level. Usually only enabled when debugging. Very verbose logging.
|
|||
DebugLevel |
|||
) |
|||
|
|||
// Won't compile if StdLogger can't be realized by a log.Logger
|
|||
var ( |
|||
_ StdLogger = &log.Logger{} |
|||
_ StdLogger = &Entry{} |
|||
_ StdLogger = &Logger{} |
|||
) |
|||
|
|||
// StdLogger is what your logrus-enabled library should take, that way
|
|||
// it'll accept a stdlib logger and a logrus logger. There's no standard
|
|||
// interface, this is the closest we get, unfortunately.
|
|||
type StdLogger interface { |
|||
Print(...interface{}) |
|||
Printf(string, ...interface{}) |
|||
Println(...interface{}) |
|||
|
|||
Fatal(...interface{}) |
|||
Fatalf(string, ...interface{}) |
|||
Fatalln(...interface{}) |
|||
|
|||
Panic(...interface{}) |
|||
Panicf(string, ...interface{}) |
|||
Panicln(...interface{}) |
|||
} |
|||
|
|||
// The FieldLogger interface generalizes the Entry and Logger types
|
|||
type FieldLogger interface { |
|||
WithField(key string, value interface{}) *Entry |
|||
WithFields(fields Fields) *Entry |
|||
WithError(err error) *Entry |
|||
|
|||
Debugf(format string, args ...interface{}) |
|||
Infof(format string, args ...interface{}) |
|||
Printf(format string, args ...interface{}) |
|||
Warnf(format string, args ...interface{}) |
|||
Warningf(format string, args ...interface{}) |
|||
Errorf(format string, args ...interface{}) |
|||
Fatalf(format string, args ...interface{}) |
|||
Panicf(format string, args ...interface{}) |
|||
|
|||
Debug(args ...interface{}) |
|||
Info(args ...interface{}) |
|||
Print(args ...interface{}) |
|||
Warn(args ...interface{}) |
|||
Warning(args ...interface{}) |
|||
Error(args ...interface{}) |
|||
Fatal(args ...interface{}) |
|||
Panic(args ...interface{}) |
|||
|
|||
Debugln(args ...interface{}) |
|||
Infoln(args ...interface{}) |
|||
Println(args ...interface{}) |
|||
Warnln(args ...interface{}) |
|||
Warningln(args ...interface{}) |
|||
Errorln(args ...interface{}) |
|||
Fatalln(args ...interface{}) |
|||
Panicln(args ...interface{}) |
|||
} |
@ -0,0 +1,8 @@ |
|||
// +build appengine
|
|||
|
|||
package logrus |
|||
|
|||
// IsTerminal returns true if stderr's file descriptor is a terminal.
|
|||
func IsTerminal() bool { |
|||
return true |
|||
} |
@ -0,0 +1,10 @@ |
|||
// +build darwin freebsd openbsd netbsd dragonfly
|
|||
// +build !appengine
|
|||
|
|||
package logrus |
|||
|
|||
import "syscall" |
|||
|
|||
const ioctlReadTermios = syscall.TIOCGETA |
|||
|
|||
type Termios syscall.Termios |
@ -0,0 +1,14 @@ |
|||
// Based on ssh/terminal:
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build !appengine
|
|||
|
|||
package logrus |
|||
|
|||
import "syscall" |
|||
|
|||
const ioctlReadTermios = syscall.TCGETS |
|||
|
|||
type Termios syscall.Termios |
@ -0,0 +1,22 @@ |
|||
// Based on ssh/terminal:
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build linux darwin freebsd openbsd netbsd dragonfly
|
|||
// +build !appengine
|
|||
|
|||
package logrus |
|||
|
|||
import ( |
|||
"syscall" |
|||
"unsafe" |
|||
) |
|||
|
|||
// IsTerminal returns true if stderr's file descriptor is a terminal.
|
|||
func IsTerminal() bool { |
|||
fd := syscall.Stderr |
|||
var termios Termios |
|||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) |
|||
return err == 0 |
|||
} |
@ -0,0 +1,15 @@ |
|||
// +build solaris,!appengine
|
|||
|
|||
package logrus |
|||
|
|||
import ( |
|||
"os" |
|||
|
|||
"golang.org/x/sys/unix" |
|||
) |
|||
|
|||
// IsTerminal returns true if the given file descriptor is a terminal.
|
|||
func IsTerminal() bool { |
|||
_, err := unix.IoctlGetTermios(int(os.Stdout.Fd()), unix.TCGETA) |
|||
return err == nil |
|||
} |
@ -0,0 +1,27 @@ |
|||
// Based on ssh/terminal:
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build windows,!appengine
|
|||
|
|||
package logrus |
|||
|
|||
import ( |
|||
"syscall" |
|||
"unsafe" |
|||
) |
|||
|
|||
var kernel32 = syscall.NewLazyDLL("kernel32.dll") |
|||
|
|||
var ( |
|||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode") |
|||
) |
|||
|
|||
// IsTerminal returns true if stderr's file descriptor is a terminal.
|
|||
func IsTerminal() bool { |
|||
fd := syscall.Stderr |
|||
var st uint32 |
|||
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) |
|||
return r != 0 && e == 0 |
|||
} |
@ -0,0 +1,165 @@ |
|||
package logrus |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"runtime" |
|||
"sort" |
|||
"strings" |
|||
"time" |
|||
) |
|||
|
|||
const ( |
|||
nocolor = 0 |
|||
red = 31 |
|||
green = 32 |
|||
yellow = 33 |
|||
blue = 34 |
|||
gray = 37 |
|||
) |
|||
|
|||
var ( |
|||
baseTimestamp time.Time |
|||
isTerminal bool |
|||
) |
|||
|
|||
func init() { |
|||
baseTimestamp = time.Now() |
|||
isTerminal = IsTerminal() |
|||
} |
|||
|
|||
func miniTS() int { |
|||
return int(time.Since(baseTimestamp) / time.Second) |
|||
} |
|||
|
|||
type TextFormatter struct { |
|||
// Set to true to bypass checking for a TTY before outputting colors.
|
|||
ForceColors bool |
|||
|
|||
// Force disabling colors.
|
|||
DisableColors bool |
|||
|
|||
// Disable timestamp logging. useful when output is redirected to logging
|
|||
// system that already adds timestamps.
|
|||
DisableTimestamp bool |
|||
|
|||
// Enable logging the full timestamp when a TTY is attached instead of just
|
|||
// the time passed since beginning of execution.
|
|||
FullTimestamp bool |
|||
|
|||
// TimestampFormat to use for display when a full timestamp is printed
|
|||
TimestampFormat string |
|||
|
|||
// The fields are sorted by default for a consistent output. For applications
|
|||
// that log extremely frequently and don't use the JSON formatter this may not
|
|||
// be desired.
|
|||
DisableSorting bool |
|||
} |
|||
|
|||
func (f *TextFormatter) Format(entry *Entry) ([]byte, error) { |
|||
var b *bytes.Buffer |
|||
var keys []string = make([]string, 0, len(entry.Data)) |
|||
for k := range entry.Data { |
|||
keys = append(keys, k) |
|||
} |
|||
|
|||
if !f.DisableSorting { |
|||
sort.Strings(keys) |
|||
} |
|||
if entry.Buffer != nil { |
|||
b = entry.Buffer |
|||
} else { |
|||
b = &bytes.Buffer{} |
|||
} |
|||
|
|||
prefixFieldClashes(entry.Data) |
|||
|
|||
isColorTerminal := isTerminal && (runtime.GOOS != "windows") |
|||
isColored := (f.ForceColors || isColorTerminal) && !f.DisableColors |
|||
|
|||
timestampFormat := f.TimestampFormat |
|||
if timestampFormat == "" { |
|||
timestampFormat = DefaultTimestampFormat |
|||
} |
|||
if isColored { |
|||
f.printColored(b, entry, keys, timestampFormat) |
|||
} else { |
|||
if !f.DisableTimestamp { |
|||
f.appendKeyValue(b, "time", entry.Time.Format(timestampFormat)) |
|||
} |
|||
f.appendKeyValue(b, "level", entry.Level.String()) |
|||
if entry.Message != "" { |
|||
f.appendKeyValue(b, "msg", entry.Message) |
|||
} |
|||
for _, key := range keys { |
|||
f.appendKeyValue(b, key, entry.Data[key]) |
|||
} |
|||
} |
|||
|
|||
b.WriteByte('\n') |
|||
return b.Bytes(), nil |
|||
} |
|||
|
|||
func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []string, timestampFormat string) { |
|||
var levelColor int |
|||
switch entry.Level { |
|||
case DebugLevel: |
|||
levelColor = gray |
|||
case WarnLevel: |
|||
levelColor = yellow |
|||
case ErrorLevel, FatalLevel, PanicLevel: |
|||
levelColor = red |
|||
default: |
|||
levelColor = blue |
|||
} |
|||
|
|||
levelText := strings.ToUpper(entry.Level.String())[0:4] |
|||
|
|||
if !f.FullTimestamp { |
|||
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%04d] %-44s ", levelColor, levelText, miniTS(), entry.Message) |
|||
} else { |
|||
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s] %-44s ", levelColor, levelText, entry.Time.Format(timestampFormat), entry.Message) |
|||
} |
|||
for _, k := range keys { |
|||
v := entry.Data[k] |
|||
fmt.Fprintf(b, " \x1b[%dm%s\x1b[0m=%+v", levelColor, k, v) |
|||
} |
|||
} |
|||
|
|||
func needsQuoting(text string) bool { |
|||
for _, ch := range text { |
|||
if !((ch >= 'a' && ch <= 'z') || |
|||
(ch >= 'A' && ch <= 'Z') || |
|||
(ch >= '0' && ch <= '9') || |
|||
ch == '-' || ch == '.') { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
func (f *TextFormatter) appendKeyValue(b *bytes.Buffer, key string, value interface{}) { |
|||
|
|||
b.WriteString(key) |
|||
b.WriteByte('=') |
|||
|
|||
switch value := value.(type) { |
|||
case string: |
|||
if !needsQuoting(value) { |
|||
b.WriteString(value) |
|||
} else { |
|||
fmt.Fprintf(b, "%q", value) |
|||
} |
|||
case error: |
|||
errmsg := value.Error() |
|||
if !needsQuoting(errmsg) { |
|||
b.WriteString(errmsg) |
|||
} else { |
|||
fmt.Fprintf(b, "%q", value) |
|||
} |
|||
default: |
|||
fmt.Fprint(b, value) |
|||
} |
|||
|
|||
b.WriteByte(' ') |
|||
} |
@ -0,0 +1,53 @@ |
|||
package logrus |
|||
|
|||
import ( |
|||
"bufio" |
|||
"io" |
|||
"runtime" |
|||
) |
|||
|
|||
func (logger *Logger) Writer() *io.PipeWriter { |
|||
return logger.WriterLevel(InfoLevel) |
|||
} |
|||
|
|||
func (logger *Logger) WriterLevel(level Level) *io.PipeWriter { |
|||
reader, writer := io.Pipe() |
|||
|
|||
var printFunc func(args ...interface{}) |
|||
switch level { |
|||
case DebugLevel: |
|||
printFunc = logger.Debug |
|||
case InfoLevel: |
|||
printFunc = logger.Info |
|||
case WarnLevel: |
|||
printFunc = logger.Warn |
|||
case ErrorLevel: |
|||
printFunc = logger.Error |
|||
case FatalLevel: |
|||
printFunc = logger.Fatal |
|||
case PanicLevel: |
|||
printFunc = logger.Panic |
|||
default: |
|||
printFunc = logger.Print |
|||
} |
|||
|
|||
go logger.writerScanner(reader, printFunc) |
|||
runtime.SetFinalizer(writer, writerFinalizer) |
|||
|
|||
return writer |
|||
} |
|||
|
|||
func (logger *Logger) writerScanner(reader *io.PipeReader, printFunc func(args ...interface{})) { |
|||
scanner := bufio.NewScanner(reader) |
|||
for scanner.Scan() { |
|||
printFunc(scanner.Text()) |
|||
} |
|||
if err := scanner.Err(); err != nil { |
|||
logger.Errorf("Error while reading from Writer: %s", err) |
|||
} |
|||
reader.Close() |
|||
} |
|||
|
|||
func writerFinalizer(writer *io.PipeWriter) { |
|||
writer.Close() |
|||
} |
@ -0,0 +1,202 @@ |
|||
|
|||
Apache License |
|||
Version 2.0, January 2004 |
|||
http://www.apache.org/licenses/ |
|||
|
|||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION |
|||
|
|||
1. Definitions. |
|||
|
|||
"License" shall mean the terms and conditions for use, reproduction, |
|||
and distribution as defined by Sections 1 through 9 of this document. |
|||
|
|||
"Licensor" shall mean the copyright owner or entity authorized by |
|||
the copyright owner that is granting the License. |
|||
|
|||
"Legal Entity" shall mean the union of the acting entity and all |
|||
other entities that control, are controlled by, or are under common |
|||
control with that entity. For the purposes of this definition, |
|||
"control" means (i) the power, direct or indirect, to cause the |
|||
direction or management of such entity, whether by contract or |
|||
otherwise, or (ii) ownership of fifty percent (50%) or more of the |
|||
outstanding shares, or (iii) beneficial ownership of such entity. |
|||
|
|||
"You" (or "Your") shall mean an individual or Legal Entity |
|||
exercising permissions granted by this License. |
|||
|
|||
"Source" form shall mean the preferred form for making modifications, |
|||
including but not limited to software source code, documentation |
|||
source, and configuration files. |
|||
|
|||
"Object" form shall mean any form resulting from mechanical |
|||
transformation or translation of a Source form, including but |
|||
not limited to compiled object code, generated documentation, |
|||
and conversions to other media types. |
|||
|
|||
"Work" shall mean the work of authorship, whether in Source or |
|||
Object form, made available under the License, as indicated by a |
|||
copyright notice that is included in or attached to the work |
|||
(an example is provided in the Appendix below). |
|||
|
|||
"Derivative Works" shall mean any work, whether in Source or Object |
|||
form, that is based on (or derived from) the Work and for which the |
|||
editorial revisions, annotations, elaborations, or other modifications |
|||
represent, as a whole, an original work of authorship. For the purposes |
|||
of this License, Derivative Works shall not include works that remain |
|||
separable from, or merely link (or bind by name) to the interfaces of, |
|||
the Work and Derivative Works thereof. |
|||
|
|||
"Contribution" shall mean any work of authorship, including |
|||
the original version of the Work and any modifications or additions |
|||
to that Work or Derivative Works thereof, that is intentionally |
|||
submitted to Licensor for inclusion in the Work by the copyright owner |
|||
or by an individual or Legal Entity authorized to submit on behalf of |
|||
the copyright owner. For the purposes of this definition, "submitted" |
|||
means any form of electronic, verbal, or written communication sent |
|||
to the Licensor or its representatives, including but not limited to |
|||
communication on electronic mailing lists, source code control systems, |
|||
and issue tracking systems that are managed by, or on behalf of, the |
|||
Licensor for the purpose of discussing and improving the Work, but |
|||
excluding communication that is conspicuously marked or otherwise |
|||
designated in writing by the copyright owner as "Not a Contribution." |
|||
|
|||
"Contributor" shall mean Licensor and any individual or Legal Entity |
|||
on behalf of whom a Contribution has been received by Licensor and |
|||
subsequently incorporated within the Work. |
|||
|
|||
2. Grant of Copyright License. Subject to the terms and conditions of |
|||
this License, each Contributor hereby grants to You a perpetual, |
|||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
|||
copyright license to reproduce, prepare Derivative Works of, |
|||
publicly display, publicly perform, sublicense, and distribute the |
|||
Work and such Derivative Works in Source or Object form. |
|||
|
|||
3. Grant of Patent License. Subject to the terms and conditions of |
|||
this License, each Contributor hereby grants to You a perpetual, |
|||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
|||
(except as stated in this section) patent license to make, have made, |
|||
use, offer to sell, sell, import, and otherwise transfer the Work, |
|||
where such license applies only to those patent claims licensable |
|||
by such Contributor that are necessarily infringed by their |
|||
Contribution(s) alone or by combination of their Contribution(s) |
|||
with the Work to which such Contribution(s) was submitted. If You |
|||
institute patent litigation against any entity (including a |
|||
cross-claim or counterclaim in a lawsuit) alleging that the Work |
|||
or a Contribution incorporated within the Work constitutes direct |
|||
or contributory patent infringement, then any patent licenses |
|||
granted to You under this License for that Work shall terminate |
|||
as of the date such litigation is filed. |
|||
|
|||
4. Redistribution. You may reproduce and distribute copies of the |
|||
Work or Derivative Works thereof in any medium, with or without |
|||
modifications, and in Source or Object form, provided that You |
|||
meet the following conditions: |
|||
|
|||
(a) You must give any other recipients of the Work or |
|||
Derivative Works a copy of this License; and |
|||
|
|||
(b) You must cause any modified files to carry prominent notices |
|||
stating that You changed the files; and |
|||
|
|||
(c) You must retain, in the Source form of any Derivative Works |
|||
that You distribute, all copyright, patent, trademark, and |
|||
attribution notices from the Source form of the Work, |
|||
excluding those notices that do not pertain to any part of |
|||
the Derivative Works; and |
|||
|
|||
(d) If the Work includes a "NOTICE" text file as part of its |
|||
distribution, then any Derivative Works that You distribute must |
|||
include a readable copy of the attribution notices contained |
|||
within such NOTICE file, excluding those notices that do not |
|||
pertain to any part of the Derivative Works, in at least one |
|||
of the following places: within a NOTICE text file distributed |
|||
as part of the Derivative Works; within the Source form or |
|||
documentation, if provided along with the Derivative Works; or, |
|||
within a display generated by the Derivative Works, if and |
|||
wherever such third-party notices normally appear. The contents |
|||
of the NOTICE file are for informational purposes only and |
|||
do not modify the License. You may add Your own attribution |
|||
notices within Derivative Works that You distribute, alongside |
|||
or as an addendum to the NOTICE text from the Work, provided |
|||
that such additional attribution notices cannot be construed |
|||
as modifying the License. |
|||
|
|||
You may add Your own copyright statement to Your modifications and |
|||
may provide additional or different license terms and conditions |
|||
for use, reproduction, or distribution of Your modifications, or |
|||
for any such Derivative Works as a whole, provided Your use, |
|||
reproduction, and distribution of the Work otherwise complies with |
|||
the conditions stated in this License. |
|||
|
|||
5. Submission of Contributions. Unless You explicitly state otherwise, |
|||
any Contribution intentionally submitted for inclusion in the Work |
|||
by You to the Licensor shall be under the terms and conditions of |
|||
this License, without any additional terms or conditions. |
|||
Notwithstanding the above, nothing herein shall supersede or modify |
|||
the terms of any separate license agreement you may have executed |
|||
with Licensor regarding such Contributions. |
|||
|
|||
6. Trademarks. This License does not grant permission to use the trade |
|||
names, trademarks, service marks, or product names of the Licensor, |
|||
except as required for reasonable and customary use in describing the |
|||
origin of the Work and reproducing the content of the NOTICE file. |
|||
|
|||
7. Disclaimer of Warranty. Unless required by applicable law or |
|||
agreed to in writing, Licensor provides the Work (and each |
|||
Contributor provides its Contributions) on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|||
implied, including, without limitation, any warranties or conditions |
|||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A |
|||
PARTICULAR PURPOSE. You are solely responsible for determining the |
|||
appropriateness of using or redistributing the Work and assume any |
|||
risks associated with Your exercise of permissions under this License. |
|||
|
|||
8. Limitation of Liability. In no event and under no legal theory, |
|||
whether in tort (including negligence), contract, or otherwise, |
|||
unless required by applicable law (such as deliberate and grossly |
|||
negligent acts) or agreed to in writing, shall any Contributor be |
|||
liable to You for damages, including any direct, indirect, special, |
|||
incidental, or consequential damages of any character arising as a |
|||
result of this License or out of the use or inability to use the |
|||
Work (including but not limited to damages for loss of goodwill, |
|||
work stoppage, computer failure or malfunction, or any and all |
|||
other commercial damages or losses), even if such Contributor |
|||
has been advised of the possibility of such damages. |
|||
|
|||
9. Accepting Warranty or Additional Liability. While redistributing |
|||
the Work or Derivative Works thereof, You may choose to offer, |
|||
and charge a fee for, acceptance of support, warranty, indemnity, |
|||
or other liability obligations and/or rights consistent with this |
|||
License. However, in accepting such obligations, You may act only |
|||
on Your own behalf and on Your sole responsibility, not on behalf |
|||
of any other Contributor, and only if You agree to indemnify, |
|||
defend, and hold each Contributor harmless for any liability |
|||
incurred by, or claims asserted against, such Contributor by reason |
|||
of your accepting any such warranty or additional liability. |
|||
|
|||
END OF TERMS AND CONDITIONS |
|||
|
|||
APPENDIX: How to apply the Apache License to your work. |
|||
|
|||
To apply the Apache License to your work, attach the following |
|||
boilerplate notice, with the fields enclosed by brackets "[]" |
|||
replaced with your own identifying information. (Don't include |
|||
the brackets!) The text should be enclosed in the appropriate |
|||
comment syntax for the file format. We also recommend that a |
|||
file or class name and description of purpose be included on the |
|||
same "printed page" as the copyright notice for easier |
|||
identification within third-party archives. |
|||
|
|||
Copyright [yyyy] [name of copyright owner] |
|||
|
|||
Licensed under the Apache License, Version 2.0 (the "License"); |
|||
you may not use this file except in compliance with the License. |
|||
You may obtain a copy of the License at |
|||
|
|||
http://www.apache.org/licenses/LICENSE-2.0 |
|||
|
|||
Unless required by applicable law or agreed to in writing, software |
|||
distributed under the License is distributed on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
See the License for the specific language governing permissions and |
|||
limitations under the License. |
@ -0,0 +1,3 @@ |
|||
AWS SDK for Go |
|||
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
|||
Copyright 2014-2015 Stripe, Inc. |
@ -0,0 +1,145 @@ |
|||
// Package awserr represents API error interface accessors for the SDK.
|
|||
package awserr |
|||
|
|||
// An Error wraps lower level errors with code, message and an original error.
|
|||
// The underlying concrete error type may also satisfy other interfaces which
|
|||
// can be to used to obtain more specific information about the error.
|
|||
//
|
|||
// Calling Error() or String() will always include the full information about
|
|||
// an error based on its underlying type.
|
|||
//
|
|||
// Example:
|
|||
//
|
|||
// output, err := s3manage.Upload(svc, input, opts)
|
|||
// if err != nil {
|
|||
// if awsErr, ok := err.(awserr.Error); ok {
|
|||
// // Get error details
|
|||
// log.Println("Error:", awsErr.Code(), awsErr.Message())
|
|||
//
|
|||
// // Prints out full error message, including original error if there was one.
|
|||
// log.Println("Error:", awsErr.Error())
|
|||
//
|
|||
// // Get original error
|
|||
// if origErr := awsErr.OrigErr(); origErr != nil {
|
|||
// // operate on original error.
|
|||
// }
|
|||
// } else {
|
|||
// fmt.Println(err.Error())
|
|||
// }
|
|||
// }
|
|||
//
|
|||
type Error interface { |
|||
// Satisfy the generic error interface.
|
|||
error |
|||
|
|||
// Returns the short phrase depicting the classification of the error.
|
|||
Code() string |
|||
|
|||
// Returns the error details message.
|
|||
Message() string |
|||
|
|||
// Returns the original error if one was set. Nil is returned if not set.
|
|||
OrigErr() error |
|||
} |
|||
|
|||
// BatchError is a batch of errors which also wraps lower level errors with
|
|||
// code, message, and original errors. Calling Error() will include all errors
|
|||
// that occurred in the batch.
|
|||
//
|
|||
// Deprecated: Replaced with BatchedErrors. Only defined for backwards
|
|||
// compatibility.
|
|||
type BatchError interface { |
|||
// Satisfy the generic error interface.
|
|||
error |
|||
|
|||
// Returns the short phrase depicting the classification of the error.
|
|||
Code() string |
|||
|
|||
// Returns the error details message.
|
|||
Message() string |
|||
|
|||
// Returns the original error if one was set. Nil is returned if not set.
|
|||
OrigErrs() []error |
|||
} |
|||
|
|||
// BatchedErrors is a batch of errors which also wraps lower level errors with
|
|||
// code, message, and original errors. Calling Error() will include all errors
|
|||
// that occurred in the batch.
|
|||
//
|
|||
// Replaces BatchError
|
|||
type BatchedErrors interface { |
|||
// Satisfy the base Error interface.
|
|||
Error |
|||
|
|||
// Returns the original error if one was set. Nil is returned if not set.
|
|||
OrigErrs() []error |
|||
} |
|||
|
|||
// New returns an Error object described by the code, message, and origErr.
|
|||
//
|
|||
// If origErr satisfies the Error interface it will not be wrapped within a new
|
|||
// Error object and will instead be returned.
|
|||
func New(code, message string, origErr error) Error { |
|||
var errs []error |
|||
if origErr != nil { |
|||
errs = append(errs, origErr) |
|||
} |
|||
return newBaseError(code, message, errs) |
|||
} |
|||
|
|||
// NewBatchError returns an BatchedErrors with a collection of errors as an
|
|||
// array of errors.
|
|||
func NewBatchError(code, message string, errs []error) BatchedErrors { |
|||
return newBaseError(code, message, errs) |
|||
} |
|||
|
|||
// A RequestFailure is an interface to extract request failure information from
|
|||
// an Error such as the request ID of the failed request returned by a service.
|
|||
// RequestFailures may not always have a requestID value if the request failed
|
|||
// prior to reaching the service such as a connection error.
|
|||
//
|
|||
// Example:
|
|||
//
|
|||
// output, err := s3manage.Upload(svc, input, opts)
|
|||
// if err != nil {
|
|||
// if reqerr, ok := err.(RequestFailure); ok {
|
|||
// log.Println("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID())
|
|||
// } else {
|
|||
// log.Println("Error:", err.Error())
|
|||
// }
|
|||
// }
|
|||
//
|
|||
// Combined with awserr.Error:
|
|||
//
|
|||
// output, err := s3manage.Upload(svc, input, opts)
|
|||
// if err != nil {
|
|||
// if awsErr, ok := err.(awserr.Error); ok {
|
|||
// // Generic AWS Error with Code, Message, and original error (if any)
|
|||
// fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
|
|||
//
|
|||
// if reqErr, ok := err.(awserr.RequestFailure); ok {
|
|||
// // A service error occurred
|
|||
// fmt.Println(reqErr.StatusCode(), reqErr.RequestID())
|
|||
// }
|
|||
// } else {
|
|||
// fmt.Println(err.Error())
|
|||
// }
|
|||
// }
|
|||
//
|
|||
type RequestFailure interface { |
|||
Error |
|||
|
|||
// The status code of the HTTP response.
|
|||
StatusCode() int |
|||
|
|||
// The request ID returned by the service for a request failure. This will
|
|||
// be empty if no request ID is available such as the request failed due
|
|||
// to a connection error.
|
|||
RequestID() string |
|||
} |
|||
|
|||
// NewRequestFailure returns a new request error wrapper for the given Error
|
|||
// provided.
|
|||
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure { |
|||
return newRequestError(err, statusCode, reqID) |
|||
} |
@ -0,0 +1,194 @@ |
|||
package awserr |
|||
|
|||
import "fmt" |
|||
|
|||
// SprintError returns a string of the formatted error code.
|
|||
//
|
|||
// Both extra and origErr are optional. If they are included their lines
|
|||
// will be added, but if they are not included their lines will be ignored.
|
|||
func SprintError(code, message, extra string, origErr error) string { |
|||
msg := fmt.Sprintf("%s: %s", code, message) |
|||
if extra != "" { |
|||
msg = fmt.Sprintf("%s\n\t%s", msg, extra) |
|||
} |
|||
if origErr != nil { |
|||
msg = fmt.Sprintf("%s\ncaused by: %s", msg, origErr.Error()) |
|||
} |
|||
return msg |
|||
} |
|||
|
|||
// A baseError wraps the code and message which defines an error. It also
|
|||
// can be used to wrap an original error object.
|
|||
//
|
|||
// Should be used as the root for errors satisfying the awserr.Error. Also
|
|||
// for any error which does not fit into a specific error wrapper type.
|
|||
type baseError struct { |
|||
// Classification of error
|
|||
code string |
|||
|
|||
// Detailed information about error
|
|||
message string |
|||
|
|||
// Optional original error this error is based off of. Allows building
|
|||
// chained errors.
|
|||
errs []error |
|||
} |
|||
|
|||
// newBaseError returns an error object for the code, message, and errors.
|
|||
//
|
|||
// code is a short no whitespace phrase depicting the classification of
|
|||
// the error that is being created.
|
|||
//
|
|||
// message is the free flow string containing detailed information about the
|
|||
// error.
|
|||
//
|
|||
// origErrs is the error objects which will be nested under the new errors to
|
|||
// be returned.
|
|||
func newBaseError(code, message string, origErrs []error) *baseError { |
|||
b := &baseError{ |
|||
code: code, |
|||
message: message, |
|||
errs: origErrs, |
|||
} |
|||
|
|||
return b |
|||
} |
|||
|
|||
// Error returns the string representation of the error.
|
|||
//
|
|||
// See ErrorWithExtra for formatting.
|
|||
//
|
|||
// Satisfies the error interface.
|
|||
func (b baseError) Error() string { |
|||
size := len(b.errs) |
|||
if size > 0 { |
|||
return SprintError(b.code, b.message, "", errorList(b.errs)) |
|||
} |
|||
|
|||
return SprintError(b.code, b.message, "", nil) |
|||
} |
|||
|
|||
// String returns the string representation of the error.
|
|||
// Alias for Error to satisfy the stringer interface.
|
|||
func (b baseError) String() string { |
|||
return b.Error() |
|||
} |
|||
|
|||
// Code returns the short phrase depicting the classification of the error.
|
|||
func (b baseError) Code() string { |
|||
return b.code |
|||
} |
|||
|
|||
// Message returns the error details message.
|
|||
func (b baseError) Message() string { |
|||
return b.message |
|||
} |
|||
|
|||
// OrigErr returns the original error if one was set. Nil is returned if no
|
|||
// error was set. This only returns the first element in the list. If the full
|
|||
// list is needed, use BatchedErrors.
|
|||
func (b baseError) OrigErr() error { |
|||
switch len(b.errs) { |
|||
case 0: |
|||
return nil |
|||
case 1: |
|||
return b.errs[0] |
|||
default: |
|||
if err, ok := b.errs[0].(Error); ok { |
|||
return NewBatchError(err.Code(), err.Message(), b.errs[1:]) |
|||
} |
|||
return NewBatchError("BatchedErrors", |
|||
"multiple errors occurred", b.errs) |
|||
} |
|||
} |
|||
|
|||
// OrigErrs returns the original errors if one was set. An empty slice is
|
|||
// returned if no error was set.
|
|||
func (b baseError) OrigErrs() []error { |
|||
return b.errs |
|||
} |
|||
|
|||
// So that the Error interface type can be included as an anonymous field
|
|||
// in the requestError struct and not conflict with the error.Error() method.
|
|||
type awsError Error |
|||
|
|||
// A requestError wraps a request or service error.
|
|||
//
|
|||
// Composed of baseError for code, message, and original error.
|
|||
type requestError struct { |
|||
awsError |
|||
statusCode int |
|||
requestID string |
|||
} |
|||
|
|||
// newRequestError returns a wrapped error with additional information for
|
|||
// request status code, and service requestID.
|
|||
//
|
|||
// Should be used to wrap all request which involve service requests. Even if
|
|||
// the request failed without a service response, but had an HTTP status code
|
|||
// that may be meaningful.
|
|||
//
|
|||
// Also wraps original errors via the baseError.
|
|||
func newRequestError(err Error, statusCode int, requestID string) *requestError { |
|||
return &requestError{ |
|||
awsError: err, |
|||
statusCode: statusCode, |
|||
requestID: requestID, |
|||
} |
|||
} |
|||
|
|||
// Error returns the string representation of the error.
|
|||
// Satisfies the error interface.
|
|||
func (r requestError) Error() string { |
|||
extra := fmt.Sprintf("status code: %d, request id: %s", |
|||
r.statusCode, r.requestID) |
|||
return SprintError(r.Code(), r.Message(), extra, r.OrigErr()) |
|||
} |
|||
|
|||
// String returns the string representation of the error.
|
|||
// Alias for Error to satisfy the stringer interface.
|
|||
func (r requestError) String() string { |
|||
return r.Error() |
|||
} |
|||
|
|||
// StatusCode returns the wrapped status code for the error
|
|||
func (r requestError) StatusCode() int { |
|||
return r.statusCode |
|||
} |
|||
|
|||
// RequestID returns the wrapped requestID
|
|||
func (r requestError) RequestID() string { |
|||
return r.requestID |
|||
} |
|||
|
|||
// OrigErrs returns the original errors if one was set. An empty slice is
|
|||
// returned if no error was set.
|
|||
func (r requestError) OrigErrs() []error { |
|||
if b, ok := r.awsError.(BatchedErrors); ok { |
|||
return b.OrigErrs() |
|||
} |
|||
return []error{r.OrigErr()} |
|||
} |
|||
|
|||
// An error list that satisfies the golang interface
|
|||
type errorList []error |
|||
|
|||
// Error returns the string representation of the error.
|
|||
//
|
|||
// Satisfies the error interface.
|
|||
func (e errorList) Error() string { |
|||
msg := "" |
|||
// How do we want to handle the array size being zero
|
|||
if size := len(e); size > 0 { |
|||
for i := 0; i < size; i++ { |
|||
msg += fmt.Sprintf("%s", e[i].Error()) |
|||
// We check the next index to see if it is within the slice.
|
|||
// If it is, then we append a newline. We do this, because unit tests
|
|||
// could be broken with the additional '\n'
|
|||
if i+1 < size { |
|||
msg += "\n" |
|||
} |
|||
} |
|||
} |
|||
return msg |
|||
} |
@ -0,0 +1,108 @@ |
|||
package awsutil |
|||
|
|||
import ( |
|||
"io" |
|||
"reflect" |
|||
"time" |
|||
) |
|||
|
|||
// Copy deeply copies a src structure to dst. Useful for copying request and
|
|||
// response structures.
|
|||
//
|
|||
// Can copy between structs of different type, but will only copy fields which
|
|||
// are assignable, and exist in both structs. Fields which are not assignable,
|
|||
// or do not exist in both structs are ignored.
|
|||
func Copy(dst, src interface{}) { |
|||
dstval := reflect.ValueOf(dst) |
|||
if !dstval.IsValid() { |
|||
panic("Copy dst cannot be nil") |
|||
} |
|||
|
|||
rcopy(dstval, reflect.ValueOf(src), true) |
|||
} |
|||
|
|||
// CopyOf returns a copy of src while also allocating the memory for dst.
|
|||
// src must be a pointer type or this operation will fail.
|
|||
func CopyOf(src interface{}) (dst interface{}) { |
|||
dsti := reflect.New(reflect.TypeOf(src).Elem()) |
|||
dst = dsti.Interface() |
|||
rcopy(dsti, reflect.ValueOf(src), true) |
|||
return |
|||
} |
|||
|
|||
// rcopy performs a recursive copy of values from the source to destination.
|
|||
//
|
|||
// root is used to skip certain aspects of the copy which are not valid
|
|||
// for the root node of a object.
|
|||
func rcopy(dst, src reflect.Value, root bool) { |
|||
if !src.IsValid() { |
|||
return |
|||
} |
|||
|
|||
switch src.Kind() { |
|||
case reflect.Ptr: |
|||
if _, ok := src.Interface().(io.Reader); ok { |
|||
if dst.Kind() == reflect.Ptr && dst.Elem().CanSet() { |
|||
dst.Elem().Set(src) |
|||
} else if dst.CanSet() { |
|||
dst.Set(src) |
|||
} |
|||
} else { |
|||
e := src.Type().Elem() |
|||
if dst.CanSet() && !src.IsNil() { |
|||
if _, ok := src.Interface().(*time.Time); !ok { |
|||
dst.Set(reflect.New(e)) |
|||
} else { |
|||
tempValue := reflect.New(e) |
|||
tempValue.Elem().Set(src.Elem()) |
|||
// Sets time.Time's unexported values
|
|||
dst.Set(tempValue) |
|||
} |
|||
} |
|||
if src.Elem().IsValid() { |
|||
// Keep the current root state since the depth hasn't changed
|
|||
rcopy(dst.Elem(), src.Elem(), root) |
|||
} |
|||
} |
|||
case reflect.Struct: |
|||
t := dst.Type() |
|||
for i := 0; i < t.NumField(); i++ { |
|||
name := t.Field(i).Name |
|||
srcVal := src.FieldByName(name) |
|||
dstVal := dst.FieldByName(name) |
|||
if srcVal.IsValid() && dstVal.CanSet() { |
|||
rcopy(dstVal, srcVal, false) |
|||
} |
|||
} |
|||
case reflect.Slice: |
|||
if src.IsNil() { |
|||
break |
|||
} |
|||
|
|||
s := reflect.MakeSlice(src.Type(), src.Len(), src.Cap()) |
|||
dst.Set(s) |
|||
for i := 0; i < src.Len(); i++ { |
|||
rcopy(dst.Index(i), src.Index(i), false) |
|||
} |
|||
case reflect.Map: |
|||
if src.IsNil() { |
|||
break |
|||
} |
|||
|
|||
s := reflect.MakeMap(src.Type()) |
|||
dst.Set(s) |
|||
for _, k := range src.MapKeys() { |
|||
v := src.MapIndex(k) |
|||
v2 := reflect.New(v.Type()).Elem() |
|||
rcopy(v2, v, false) |
|||
dst.SetMapIndex(k, v2) |
|||
} |
|||
default: |
|||
// Assign the value if possible. If its not assignable, the value would
|
|||
// need to be converted and the impact of that may be unexpected, or is
|
|||
// not compatible with the dst type.
|
|||
if src.Type().AssignableTo(dst.Type()) { |
|||
dst.Set(src) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,27 @@ |
|||
package awsutil |
|||
|
|||
import ( |
|||
"reflect" |
|||
) |
|||
|
|||
// DeepEqual returns if the two values are deeply equal like reflect.DeepEqual.
|
|||
// In addition to this, this method will also dereference the input values if
|
|||
// possible so the DeepEqual performed will not fail if one parameter is a
|
|||
// pointer and the other is not.
|
|||
//
|
|||
// DeepEqual will not perform indirection of nested values of the input parameters.
|
|||
func DeepEqual(a, b interface{}) bool { |
|||
ra := reflect.Indirect(reflect.ValueOf(a)) |
|||
rb := reflect.Indirect(reflect.ValueOf(b)) |
|||
|
|||
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid { |
|||
// If the elements are both nil, and of the same type the are equal
|
|||
// If they are of different types they are not equal
|
|||
return reflect.TypeOf(a) == reflect.TypeOf(b) |
|||
} else if raValid != rbValid { |
|||
// Both values must be valid to be equal
|
|||
return false |
|||
} |
|||
|
|||
return reflect.DeepEqual(ra.Interface(), rb.Interface()) |
|||
} |
@ -0,0 +1,222 @@ |
|||
package awsutil |
|||
|
|||
import ( |
|||
"reflect" |
|||
"regexp" |
|||
"strconv" |
|||
"strings" |
|||
|
|||
"github.com/jmespath/go-jmespath" |
|||
) |
|||
|
|||
var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`) |
|||
|
|||
// rValuesAtPath returns a slice of values found in value v. The values
|
|||
// in v are explored recursively so all nested values are collected.
|
|||
func rValuesAtPath(v interface{}, path string, createPath, caseSensitive, nilTerm bool) []reflect.Value { |
|||
pathparts := strings.Split(path, "||") |
|||
if len(pathparts) > 1 { |
|||
for _, pathpart := range pathparts { |
|||
vals := rValuesAtPath(v, pathpart, createPath, caseSensitive, nilTerm) |
|||
if len(vals) > 0 { |
|||
return vals |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
values := []reflect.Value{reflect.Indirect(reflect.ValueOf(v))} |
|||
components := strings.Split(path, ".") |
|||
for len(values) > 0 && len(components) > 0 { |
|||
var index *int64 |
|||
var indexStar bool |
|||
c := strings.TrimSpace(components[0]) |
|||
if c == "" { // no actual component, illegal syntax
|
|||
return nil |
|||
} else if caseSensitive && c != "*" && strings.ToLower(c[0:1]) == c[0:1] { |
|||
// TODO normalize case for user
|
|||
return nil // don't support unexported fields
|
|||
} |
|||
|
|||
// parse this component
|
|||
if m := indexRe.FindStringSubmatch(c); m != nil { |
|||
c = m[1] |
|||
if m[2] == "" { |
|||
index = nil |
|||
indexStar = true |
|||
} else { |
|||
i, _ := strconv.ParseInt(m[2], 10, 32) |
|||
index = &i |
|||
indexStar = false |
|||
} |
|||
} |
|||
|
|||
nextvals := []reflect.Value{} |
|||
for _, value := range values { |
|||
// pull component name out of struct member
|
|||
if value.Kind() != reflect.Struct { |
|||
continue |
|||
} |
|||
|
|||
if c == "*" { // pull all members
|
|||
for i := 0; i < value.NumField(); i++ { |
|||
if f := reflect.Indirect(value.Field(i)); f.IsValid() { |
|||
nextvals = append(nextvals, f) |
|||
} |
|||
} |
|||
continue |
|||
} |
|||
|
|||
value = value.FieldByNameFunc(func(name string) bool { |
|||
if c == name { |
|||
return true |
|||
} else if !caseSensitive && strings.ToLower(name) == strings.ToLower(c) { |
|||
return true |
|||
} |
|||
return false |
|||
}) |
|||
|
|||
if nilTerm && value.Kind() == reflect.Ptr && len(components[1:]) == 0 { |
|||
if !value.IsNil() { |
|||
value.Set(reflect.Zero(value.Type())) |
|||
} |
|||
return []reflect.Value{value} |
|||
} |
|||
|
|||
if createPath && value.Kind() == reflect.Ptr && value.IsNil() { |
|||
// TODO if the value is the terminus it should not be created
|
|||
// if the value to be set to its position is nil.
|
|||
value.Set(reflect.New(value.Type().Elem())) |
|||
value = value.Elem() |
|||
} else { |
|||
value = reflect.Indirect(value) |
|||
} |
|||
|
|||
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map { |
|||
if !createPath && value.IsNil() { |
|||
value = reflect.ValueOf(nil) |
|||
} |
|||
} |
|||
|
|||
if value.IsValid() { |
|||
nextvals = append(nextvals, value) |
|||
} |
|||
} |
|||
values = nextvals |
|||
|
|||
if indexStar || index != nil { |
|||
nextvals = []reflect.Value{} |
|||
for _, valItem := range values { |
|||
value := reflect.Indirect(valItem) |
|||
if value.Kind() != reflect.Slice { |
|||
continue |
|||
} |
|||
|
|||
if indexStar { // grab all indices
|
|||
for i := 0; i < value.Len(); i++ { |
|||
idx := reflect.Indirect(value.Index(i)) |
|||
if idx.IsValid() { |
|||
nextvals = append(nextvals, idx) |
|||
} |
|||
} |
|||
continue |
|||
} |
|||
|
|||
// pull out index
|
|||
i := int(*index) |
|||
if i >= value.Len() { // check out of bounds
|
|||
if createPath { |
|||
// TODO resize slice
|
|||
} else { |
|||
continue |
|||
} |
|||
} else if i < 0 { // support negative indexing
|
|||
i = value.Len() + i |
|||
} |
|||
value = reflect.Indirect(value.Index(i)) |
|||
|
|||
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map { |
|||
if !createPath && value.IsNil() { |
|||
value = reflect.ValueOf(nil) |
|||
} |
|||
} |
|||
|
|||
if value.IsValid() { |
|||
nextvals = append(nextvals, value) |
|||
} |
|||
} |
|||
values = nextvals |
|||
} |
|||
|
|||
components = components[1:] |
|||
} |
|||
return values |
|||
} |
|||
|
|||
// ValuesAtPath returns a list of values at the case insensitive lexical
|
|||
// path inside of a structure.
|
|||
func ValuesAtPath(i interface{}, path string) ([]interface{}, error) { |
|||
result, err := jmespath.Search(path, i) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
v := reflect.ValueOf(result) |
|||
if !v.IsValid() || (v.Kind() == reflect.Ptr && v.IsNil()) { |
|||
return nil, nil |
|||
} |
|||
if s, ok := result.([]interface{}); ok { |
|||
return s, err |
|||
} |
|||
if v.Kind() == reflect.Map && v.Len() == 0 { |
|||
return nil, nil |
|||
} |
|||
if v.Kind() == reflect.Slice { |
|||
out := make([]interface{}, v.Len()) |
|||
for i := 0; i < v.Len(); i++ { |
|||
out[i] = v.Index(i).Interface() |
|||
} |
|||
return out, nil |
|||
} |
|||
|
|||
return []interface{}{result}, nil |
|||
} |
|||
|
|||
// SetValueAtPath sets a value at the case insensitive lexical path inside
|
|||
// of a structure.
|
|||
func SetValueAtPath(i interface{}, path string, v interface{}) { |
|||
if rvals := rValuesAtPath(i, path, true, false, v == nil); rvals != nil { |
|||
for _, rval := range rvals { |
|||
if rval.Kind() == reflect.Ptr && rval.IsNil() { |
|||
continue |
|||
} |
|||
setValue(rval, v) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func setValue(dstVal reflect.Value, src interface{}) { |
|||
if dstVal.Kind() == reflect.Ptr { |
|||
dstVal = reflect.Indirect(dstVal) |
|||
} |
|||
srcVal := reflect.ValueOf(src) |
|||
|
|||
if !srcVal.IsValid() { // src is literal nil
|
|||
if dstVal.CanAddr() { |
|||
// Convert to pointer so that pointer's value can be nil'ed
|
|||
// dstVal = dstVal.Addr()
|
|||
} |
|||
dstVal.Set(reflect.Zero(dstVal.Type())) |
|||
|
|||
} else if srcVal.Kind() == reflect.Ptr { |
|||
if srcVal.IsNil() { |
|||
srcVal = reflect.Zero(dstVal.Type()) |
|||
} else { |
|||
srcVal = reflect.ValueOf(src).Elem() |
|||
} |
|||
dstVal.Set(srcVal) |
|||
} else { |
|||
dstVal.Set(srcVal) |
|||
} |
|||
|
|||
} |
@ -0,0 +1,107 @@ |
|||
package awsutil |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"io" |
|||
"reflect" |
|||
"strings" |
|||
) |
|||
|
|||
// Prettify returns the string representation of a value.
|
|||
func Prettify(i interface{}) string { |
|||
var buf bytes.Buffer |
|||
prettify(reflect.ValueOf(i), 0, &buf) |
|||
return buf.String() |
|||
} |
|||
|
|||
// prettify will recursively walk value v to build a textual
|
|||
// representation of the value.
|
|||
func prettify(v reflect.Value, indent int, buf *bytes.Buffer) { |
|||
for v.Kind() == reflect.Ptr { |
|||
v = v.Elem() |
|||
} |
|||
|
|||
switch v.Kind() { |
|||
case reflect.Struct: |
|||
strtype := v.Type().String() |
|||
if strtype == "time.Time" { |
|||
fmt.Fprintf(buf, "%s", v.Interface()) |
|||
break |
|||
} else if strings.HasPrefix(strtype, "io.") { |
|||
buf.WriteString("<buffer>") |
|||
break |
|||
} |
|||
|
|||
buf.WriteString("{\n") |
|||
|
|||
names := []string{} |
|||
for i := 0; i < v.Type().NumField(); i++ { |
|||
name := v.Type().Field(i).Name |
|||
f := v.Field(i) |
|||
if name[0:1] == strings.ToLower(name[0:1]) { |
|||
continue // ignore unexported fields
|
|||
} |
|||
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice || f.Kind() == reflect.Map) && f.IsNil() { |
|||
continue // ignore unset fields
|
|||
} |
|||
names = append(names, name) |
|||
} |
|||
|
|||
for i, n := range names { |
|||
val := v.FieldByName(n) |
|||
buf.WriteString(strings.Repeat(" ", indent+2)) |
|||
buf.WriteString(n + ": ") |
|||
prettify(val, indent+2, buf) |
|||
|
|||
if i < len(names)-1 { |
|||
buf.WriteString(",\n") |
|||
} |
|||
} |
|||
|
|||
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}") |
|||
case reflect.Slice: |
|||
nl, id, id2 := "", "", "" |
|||
if v.Len() > 3 { |
|||
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2) |
|||
} |
|||
buf.WriteString("[" + nl) |
|||
for i := 0; i < v.Len(); i++ { |
|||
buf.WriteString(id2) |
|||
prettify(v.Index(i), indent+2, buf) |
|||
|
|||
if i < v.Len()-1 { |
|||
buf.WriteString("," + nl) |
|||
} |
|||
} |
|||
|
|||
buf.WriteString(nl + id + "]") |
|||
case reflect.Map: |
|||
buf.WriteString("{\n") |
|||
|
|||
for i, k := range v.MapKeys() { |
|||
buf.WriteString(strings.Repeat(" ", indent+2)) |
|||
buf.WriteString(k.String() + ": ") |
|||
prettify(v.MapIndex(k), indent+2, buf) |
|||
|
|||
if i < v.Len()-1 { |
|||
buf.WriteString(",\n") |
|||
} |
|||
} |
|||
|
|||
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}") |
|||
default: |
|||
if !v.IsValid() { |
|||
fmt.Fprint(buf, "<invalid value>") |
|||
return |
|||
} |
|||
format := "%v" |
|||
switch v.Interface().(type) { |
|||
case string: |
|||
format = "%q" |
|||
case io.ReadSeeker, io.Reader: |
|||
format = "buffer(%p)" |
|||
} |
|||
fmt.Fprintf(buf, format, v.Interface()) |
|||
} |
|||
} |
@ -0,0 +1,89 @@ |
|||
package awsutil |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"reflect" |
|||
"strings" |
|||
) |
|||
|
|||
// StringValue returns the string representation of a value.
|
|||
func StringValue(i interface{}) string { |
|||
var buf bytes.Buffer |
|||
stringValue(reflect.ValueOf(i), 0, &buf) |
|||
return buf.String() |
|||
} |
|||
|
|||
func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) { |
|||
for v.Kind() == reflect.Ptr { |
|||
v = v.Elem() |
|||
} |
|||
|
|||
switch v.Kind() { |
|||
case reflect.Struct: |
|||
buf.WriteString("{\n") |
|||
|
|||
names := []string{} |
|||
for i := 0; i < v.Type().NumField(); i++ { |
|||
name := v.Type().Field(i).Name |
|||
f := v.Field(i) |
|||
if name[0:1] == strings.ToLower(name[0:1]) { |
|||
continue // ignore unexported fields
|
|||
} |
|||
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() { |
|||
continue // ignore unset fields
|
|||
} |
|||
names = append(names, name) |
|||
} |
|||
|
|||
for i, n := range names { |
|||
val := v.FieldByName(n) |
|||
buf.WriteString(strings.Repeat(" ", indent+2)) |
|||
buf.WriteString(n + ": ") |
|||
stringValue(val, indent+2, buf) |
|||
|
|||
if i < len(names)-1 { |
|||
buf.WriteString(",\n") |
|||
} |
|||
} |
|||
|
|||
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}") |
|||
case reflect.Slice: |
|||
nl, id, id2 := "", "", "" |
|||
if v.Len() > 3 { |
|||
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2) |
|||
} |
|||
buf.WriteString("[" + nl) |
|||
for i := 0; i < v.Len(); i++ { |
|||
buf.WriteString(id2) |
|||
stringValue(v.Index(i), indent+2, buf) |
|||
|
|||
if i < v.Len()-1 { |
|||
buf.WriteString("," + nl) |
|||
} |
|||
} |
|||
|
|||
buf.WriteString(nl + id + "]") |
|||
case reflect.Map: |
|||
buf.WriteString("{\n") |
|||
|
|||
for i, k := range v.MapKeys() { |
|||
buf.WriteString(strings.Repeat(" ", indent+2)) |
|||
buf.WriteString(k.String() + ": ") |
|||
stringValue(v.MapIndex(k), indent+2, buf) |
|||
|
|||
if i < v.Len()-1 { |
|||
buf.WriteString(",\n") |
|||
} |
|||
} |
|||
|
|||
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}") |
|||
default: |
|||
format := "%v" |
|||
switch v.Interface().(type) { |
|||
case string: |
|||
format = "%q" |
|||
} |
|||
fmt.Fprintf(buf, format, v.Interface()) |
|||
} |
|||
} |
@ -0,0 +1,137 @@ |
|||
package client |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http/httputil" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/client/metadata" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
// A Config provides configuration to a service client instance.
|
|||
type Config struct { |
|||
Config *aws.Config |
|||
Handlers request.Handlers |
|||
Endpoint, SigningRegion string |
|||
} |
|||
|
|||
// ConfigProvider provides a generic way for a service client to receive
|
|||
// the ClientConfig without circular dependencies.
|
|||
type ConfigProvider interface { |
|||
ClientConfig(serviceName string, cfgs ...*aws.Config) Config |
|||
} |
|||
|
|||
// A Client implements the base client request and response handling
|
|||
// used by all service clients.
|
|||
type Client struct { |
|||
request.Retryer |
|||
metadata.ClientInfo |
|||
|
|||
Config aws.Config |
|||
Handlers request.Handlers |
|||
} |
|||
|
|||
// New will return a pointer to a new initialized service client.
|
|||
func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, options ...func(*Client)) *Client { |
|||
svc := &Client{ |
|||
Config: cfg, |
|||
ClientInfo: info, |
|||
Handlers: handlers, |
|||
} |
|||
|
|||
switch retryer, ok := cfg.Retryer.(request.Retryer); { |
|||
case ok: |
|||
svc.Retryer = retryer |
|||
case cfg.Retryer != nil && cfg.Logger != nil: |
|||
s := fmt.Sprintf("WARNING: %T does not implement request.Retryer; using DefaultRetryer instead", cfg.Retryer) |
|||
cfg.Logger.Log(s) |
|||
fallthrough |
|||
default: |
|||
maxRetries := aws.IntValue(cfg.MaxRetries) |
|||
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries { |
|||
maxRetries = 3 |
|||
} |
|||
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries} |
|||
} |
|||
|
|||
svc.AddDebugHandlers() |
|||
|
|||
for _, option := range options { |
|||
option(svc) |
|||
} |
|||
|
|||
return svc |
|||
} |
|||
|
|||
// NewRequest returns a new Request pointer for the service API
|
|||
// operation and parameters.
|
|||
func (c *Client) NewRequest(operation *request.Operation, params interface{}, data interface{}) *request.Request { |
|||
return request.New(c.Config, c.ClientInfo, c.Handlers, c.Retryer, operation, params, data) |
|||
} |
|||
|
|||
// AddDebugHandlers injects debug logging handlers into the service to log request
|
|||
// debug information.
|
|||
func (c *Client) AddDebugHandlers() { |
|||
if !c.Config.LogLevel.AtLeast(aws.LogDebug) { |
|||
return |
|||
} |
|||
|
|||
c.Handlers.Send.PushFront(logRequest) |
|||
c.Handlers.Send.PushBack(logResponse) |
|||
} |
|||
|
|||
const logReqMsg = `DEBUG: Request %s/%s Details: |
|||
---[ REQUEST POST-SIGN ]----------------------------- |
|||
%s |
|||
-----------------------------------------------------` |
|||
|
|||
const logReqErrMsg = `DEBUG ERROR: Request %s/%s: |
|||
---[ REQUEST DUMP ERROR ]----------------------------- |
|||
%s |
|||
-----------------------------------------------------` |
|||
|
|||
func logRequest(r *request.Request) { |
|||
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) |
|||
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody) |
|||
if err != nil { |
|||
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err)) |
|||
return |
|||
} |
|||
|
|||
if logBody { |
|||
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
|
|||
// Body as a NoOpCloser and will not be reset after read by the HTTP
|
|||
// client reader.
|
|||
r.ResetBody() |
|||
} |
|||
|
|||
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody))) |
|||
} |
|||
|
|||
const logRespMsg = `DEBUG: Response %s/%s Details: |
|||
---[ RESPONSE ]-------------------------------------- |
|||
%s |
|||
-----------------------------------------------------` |
|||
|
|||
const logRespErrMsg = `DEBUG ERROR: Response %s/%s: |
|||
---[ RESPONSE DUMP ERROR ]----------------------------- |
|||
%s |
|||
-----------------------------------------------------` |
|||
|
|||
func logResponse(r *request.Request) { |
|||
var msg = "no response data" |
|||
if r.HTTPResponse != nil { |
|||
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) |
|||
dumpedBody, err := httputil.DumpResponse(r.HTTPResponse, logBody) |
|||
if err != nil { |
|||
r.Config.Logger.Log(fmt.Sprintf(logRespErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err)) |
|||
return |
|||
} |
|||
|
|||
msg = string(dumpedBody) |
|||
} else if r.Error != nil { |
|||
msg = r.Error.Error() |
|||
} |
|||
r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ClientInfo.ServiceName, r.Operation.Name, msg)) |
|||
} |
@ -0,0 +1,90 @@ |
|||
package client |
|||
|
|||
import ( |
|||
"math/rand" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
// DefaultRetryer implements basic retry logic using exponential backoff for
|
|||
// most services. If you want to implement custom retry logic, implement the
|
|||
// request.Retryer interface or create a structure type that composes this
|
|||
// struct and override the specific methods. For example, to override only
|
|||
// the MaxRetries method:
|
|||
//
|
|||
// type retryer struct {
|
|||
// service.DefaultRetryer
|
|||
// }
|
|||
//
|
|||
// // This implementation always has 100 max retries
|
|||
// func (d retryer) MaxRetries() uint { return 100 }
|
|||
type DefaultRetryer struct { |
|||
NumMaxRetries int |
|||
} |
|||
|
|||
// MaxRetries returns the number of maximum returns the service will use to make
|
|||
// an individual API request.
|
|||
func (d DefaultRetryer) MaxRetries() int { |
|||
return d.NumMaxRetries |
|||
} |
|||
|
|||
var seededRand = rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())}) |
|||
|
|||
// RetryRules returns the delay duration before retrying this request again
|
|||
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration { |
|||
// Set the upper limit of delay in retrying at ~five minutes
|
|||
minTime := 30 |
|||
throttle := d.shouldThrottle(r) |
|||
if throttle { |
|||
minTime = 500 |
|||
} |
|||
|
|||
retryCount := r.RetryCount |
|||
if retryCount > 13 { |
|||
retryCount = 13 |
|||
} else if throttle && retryCount > 8 { |
|||
retryCount = 8 |
|||
} |
|||
|
|||
delay := (1 << uint(retryCount)) * (seededRand.Intn(minTime) + minTime) |
|||
return time.Duration(delay) * time.Millisecond |
|||
} |
|||
|
|||
// ShouldRetry returns true if the request should be retried.
|
|||
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool { |
|||
if r.HTTPResponse.StatusCode >= 500 { |
|||
return true |
|||
} |
|||
return r.IsErrorRetryable() || d.shouldThrottle(r) |
|||
} |
|||
|
|||
// ShouldThrottle returns true if the request should be throttled.
|
|||
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool { |
|||
if r.HTTPResponse.StatusCode == 502 || |
|||
r.HTTPResponse.StatusCode == 503 || |
|||
r.HTTPResponse.StatusCode == 504 { |
|||
return true |
|||
} |
|||
return r.IsErrorThrottle() |
|||
} |
|||
|
|||
// lockedSource is a thread-safe implementation of rand.Source
|
|||
type lockedSource struct { |
|||
lk sync.Mutex |
|||
src rand.Source |
|||
} |
|||
|
|||
func (r *lockedSource) Int63() (n int64) { |
|||
r.lk.Lock() |
|||
n = r.src.Int63() |
|||
r.lk.Unlock() |
|||
return |
|||
} |
|||
|
|||
func (r *lockedSource) Seed(seed int64) { |
|||
r.lk.Lock() |
|||
r.src.Seed(seed) |
|||
r.lk.Unlock() |
|||
} |
@ -0,0 +1,12 @@ |
|||
package metadata |
|||
|
|||
// ClientInfo wraps immutable data from the client.Client structure.
|
|||
type ClientInfo struct { |
|||
ServiceName string |
|||
APIVersion string |
|||
Endpoint string |
|||
SigningName string |
|||
SigningRegion string |
|||
JSONVersion string |
|||
TargetPrefix string |
|||
} |
@ -0,0 +1,422 @@ |
|||
package aws |
|||
|
|||
import ( |
|||
"net/http" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
) |
|||
|
|||
// UseServiceDefaultRetries instructs the config to use the service's own
|
|||
// default number of retries. This will be the default action if
|
|||
// Config.MaxRetries is nil also.
|
|||
const UseServiceDefaultRetries = -1 |
|||
|
|||
// RequestRetryer is an alias for a type that implements the request.Retryer
|
|||
// interface.
|
|||
type RequestRetryer interface{} |
|||
|
|||
// A Config provides service configuration for service clients. By default,
|
|||
// all clients will use the defaults.DefaultConfig tructure.
|
|||
//
|
|||
// // Create Session with MaxRetry configuration to be shared by multiple
|
|||
// // service clients.
|
|||
// sess, err := session.NewSession(&aws.Config{
|
|||
// MaxRetries: aws.Int(3),
|
|||
// })
|
|||
//
|
|||
// // Create S3 service client with a specific Region.
|
|||
// svc := s3.New(sess, &aws.Config{
|
|||
// Region: aws.String("us-west-2"),
|
|||
// })
|
|||
type Config struct { |
|||
// Enables verbose error printing of all credential chain errors.
|
|||
// Should be used when wanting to see all errors while attempting to
|
|||
// retrieve credentials.
|
|||
CredentialsChainVerboseErrors *bool |
|||
|
|||
// The credentials object to use when signing requests. Defaults to a
|
|||
// chain of credential providers to search for credentials in environment
|
|||
// variables, shared credential file, and EC2 Instance Roles.
|
|||
Credentials *credentials.Credentials |
|||
|
|||
// An optional endpoint URL (hostname only or fully qualified URI)
|
|||
// that overrides the default generated endpoint for a client. Set this
|
|||
// to `""` to use the default generated endpoint.
|
|||
//
|
|||
// @note You must still provide a `Region` value when specifying an
|
|||
// endpoint for a client.
|
|||
Endpoint *string |
|||
|
|||
// The region to send requests to. This parameter is required and must
|
|||
// be configured globally or on a per-client basis unless otherwise
|
|||
// noted. A full list of regions is found in the "Regions and Endpoints"
|
|||
// document.
|
|||
//
|
|||
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
|
|||
// AWS Regions and Endpoints
|
|||
Region *string |
|||
|
|||
// Set this to `true` to disable SSL when sending requests. Defaults
|
|||
// to `false`.
|
|||
DisableSSL *bool |
|||
|
|||
// The HTTP client to use when sending requests. Defaults to
|
|||
// `http.DefaultClient`.
|
|||
HTTPClient *http.Client |
|||
|
|||
// An integer value representing the logging level. The default log level
|
|||
// is zero (LogOff), which represents no logging. To enable logging set
|
|||
// to a LogLevel Value.
|
|||
LogLevel *LogLevelType |
|||
|
|||
// The logger writer interface to write logging messages to. Defaults to
|
|||
// standard out.
|
|||
Logger Logger |
|||
|
|||
// The maximum number of times that a request will be retried for failures.
|
|||
// Defaults to -1, which defers the max retry setting to the service
|
|||
// specific configuration.
|
|||
MaxRetries *int |
|||
|
|||
// Retryer guides how HTTP requests should be retried in case of
|
|||
// recoverable failures.
|
|||
//
|
|||
// When nil or the value does not implement the request.Retryer interface,
|
|||
// the request.DefaultRetryer will be used.
|
|||
//
|
|||
// When both Retryer and MaxRetries are non-nil, the former is used and
|
|||
// the latter ignored.
|
|||
//
|
|||
// To set the Retryer field in a type-safe manner and with chaining, use
|
|||
// the request.WithRetryer helper function:
|
|||
//
|
|||
// cfg := request.WithRetryer(aws.NewConfig(), myRetryer)
|
|||
//
|
|||
Retryer RequestRetryer |
|||
|
|||
// Disables semantic parameter validation, which validates input for
|
|||
// missing required fields and/or other semantic request input errors.
|
|||
DisableParamValidation *bool |
|||
|
|||
// Disables the computation of request and response checksums, e.g.,
|
|||
// CRC32 checksums in Amazon DynamoDB.
|
|||
DisableComputeChecksums *bool |
|||
|
|||
// Set this to `true` to force the request to use path-style addressing,
|
|||
// i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client
|
|||
// will use virtual hosted bucket addressing when possible
|
|||
// (`http://BUCKET.s3.amazonaws.com/KEY`).
|
|||
//
|
|||
// @note This configuration option is specific to the Amazon S3 service.
|
|||
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
|
|||
// Amazon S3: Virtual Hosting of Buckets
|
|||
S3ForcePathStyle *bool |
|||
|
|||
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
|
|||
// header to PUT requests over 2MB of content. 100-Continue instructs the
|
|||
// HTTP client not to send the body until the service responds with a
|
|||
// `continue` status. This is useful to prevent sending the request body
|
|||
// until after the request is authenticated, and validated.
|
|||
//
|
|||
// http://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectPUT.html
|
|||
//
|
|||
// 100-Continue is only enabled for Go 1.6 and above. See `http.Transport`'s
|
|||
// `ExpectContinueTimeout` for information on adjusting the continue wait
|
|||
// timeout. https://golang.org/pkg/net/http/#Transport
|
|||
//
|
|||
// You should use this flag to disble 100-Continue if you experience issues
|
|||
// with proxies or third party S3 compatible services.
|
|||
S3Disable100Continue *bool |
|||
|
|||
// Set this to `true` to enable S3 Accelerate feature. For all operations
|
|||
// compatible with S3 Accelerate will use the accelerate endpoint for
|
|||
// requests. Requests not compatible will fall back to normal S3 requests.
|
|||
//
|
|||
// The bucket must be enable for accelerate to be used with S3 client with
|
|||
// accelerate enabled. If the bucket is not enabled for accelerate an error
|
|||
// will be returned. The bucket name must be DNS compatible to also work
|
|||
// with accelerate.
|
|||
//
|
|||
// Not compatible with UseDualStack requests will fail if both flags are
|
|||
// specified.
|
|||
S3UseAccelerate *bool |
|||
|
|||
// Set this to `true` to disable the EC2Metadata client from overriding the
|
|||
// default http.Client's Timeout. This is helpful if you do not want the
|
|||
// EC2Metadata client to create a new http.Client. This options is only
|
|||
// meaningful if you're not already using a custom HTTP client with the
|
|||
// SDK. Enabled by default.
|
|||
//
|
|||
// Must be set and provided to the session.NewSession() in order to disable
|
|||
// the EC2Metadata overriding the timeout for default credentials chain.
|
|||
//
|
|||
// Example:
|
|||
// sess, err := session.NewSession(aws.NewConfig().WithEC2MetadataDiableTimeoutOverride(true))
|
|||
//
|
|||
// svc := s3.New(sess)
|
|||
//
|
|||
EC2MetadataDisableTimeoutOverride *bool |
|||
|
|||
// Instructs the endpiont to be generated for a service client to
|
|||
// be the dual stack endpoint. The dual stack endpoint will support
|
|||
// both IPv4 and IPv6 addressing.
|
|||
//
|
|||
// Setting this for a service which does not support dual stack will fail
|
|||
// to make requets. It is not recommended to set this value on the session
|
|||
// as it will apply to all service clients created with the session. Even
|
|||
// services which don't support dual stack endpoints.
|
|||
//
|
|||
// If the Endpoint config value is also provided the UseDualStack flag
|
|||
// will be ignored.
|
|||
//
|
|||
// Only supported with.
|
|||
//
|
|||
// sess, err := session.NewSession()
|
|||
//
|
|||
// svc := s3.New(sess, &aws.Config{
|
|||
// UseDualStack: aws.Bool(true),
|
|||
// })
|
|||
UseDualStack *bool |
|||
|
|||
// SleepDelay is an override for the func the SDK will call when sleeping
|
|||
// during the lifecycle of a request. Specifically this will be used for
|
|||
// request delays. This value should only be used for testing. To adjust
|
|||
// the delay of a request see the aws/client.DefaultRetryer and
|
|||
// aws/request.Retryer.
|
|||
SleepDelay func(time.Duration) |
|||
} |
|||
|
|||
// NewConfig returns a new Config pointer that can be chained with builder
|
|||
// methods to set multiple configuration values inline without using pointers.
|
|||
//
|
|||
// // Create Session with MaxRetry configuration to be shared by multiple
|
|||
// // service clients.
|
|||
// sess, err := session.NewSession(aws.NewConfig().
|
|||
// WithMaxRetries(3),
|
|||
// )
|
|||
//
|
|||
// // Create S3 service client with a specific Region.
|
|||
// svc := s3.New(sess, aws.NewConfig().
|
|||
// WithRegion("us-west-2"),
|
|||
// )
|
|||
func NewConfig() *Config { |
|||
return &Config{} |
|||
} |
|||
|
|||
// WithCredentialsChainVerboseErrors sets a config verbose errors boolean and returning
|
|||
// a Config pointer.
|
|||
func (c *Config) WithCredentialsChainVerboseErrors(verboseErrs bool) *Config { |
|||
c.CredentialsChainVerboseErrors = &verboseErrs |
|||
return c |
|||
} |
|||
|
|||
// WithCredentials sets a config Credentials value returning a Config pointer
|
|||
// for chaining.
|
|||
func (c *Config) WithCredentials(creds *credentials.Credentials) *Config { |
|||
c.Credentials = creds |
|||
return c |
|||
} |
|||
|
|||
// WithEndpoint sets a config Endpoint value returning a Config pointer for
|
|||
// chaining.
|
|||
func (c *Config) WithEndpoint(endpoint string) *Config { |
|||
c.Endpoint = &endpoint |
|||
return c |
|||
} |
|||
|
|||
// WithRegion sets a config Region value returning a Config pointer for
|
|||
// chaining.
|
|||
func (c *Config) WithRegion(region string) *Config { |
|||
c.Region = ®ion |
|||
return c |
|||
} |
|||
|
|||
// WithDisableSSL sets a config DisableSSL value returning a Config pointer
|
|||
// for chaining.
|
|||
func (c *Config) WithDisableSSL(disable bool) *Config { |
|||
c.DisableSSL = &disable |
|||
return c |
|||
} |
|||
|
|||
// WithHTTPClient sets a config HTTPClient value returning a Config pointer
|
|||
// for chaining.
|
|||
func (c *Config) WithHTTPClient(client *http.Client) *Config { |
|||
c.HTTPClient = client |
|||
return c |
|||
} |
|||
|
|||
// WithMaxRetries sets a config MaxRetries value returning a Config pointer
|
|||
// for chaining.
|
|||
func (c *Config) WithMaxRetries(max int) *Config { |
|||
c.MaxRetries = &max |
|||
return c |
|||
} |
|||
|
|||
// WithDisableParamValidation sets a config DisableParamValidation value
|
|||
// returning a Config pointer for chaining.
|
|||
func (c *Config) WithDisableParamValidation(disable bool) *Config { |
|||
c.DisableParamValidation = &disable |
|||
return c |
|||
} |
|||
|
|||
// WithDisableComputeChecksums sets a config DisableComputeChecksums value
|
|||
// returning a Config pointer for chaining.
|
|||
func (c *Config) WithDisableComputeChecksums(disable bool) *Config { |
|||
c.DisableComputeChecksums = &disable |
|||
return c |
|||
} |
|||
|
|||
// WithLogLevel sets a config LogLevel value returning a Config pointer for
|
|||
// chaining.
|
|||
func (c *Config) WithLogLevel(level LogLevelType) *Config { |
|||
c.LogLevel = &level |
|||
return c |
|||
} |
|||
|
|||
// WithLogger sets a config Logger value returning a Config pointer for
|
|||
// chaining.
|
|||
func (c *Config) WithLogger(logger Logger) *Config { |
|||
c.Logger = logger |
|||
return c |
|||
} |
|||
|
|||
// WithS3ForcePathStyle sets a config S3ForcePathStyle value returning a Config
|
|||
// pointer for chaining.
|
|||
func (c *Config) WithS3ForcePathStyle(force bool) *Config { |
|||
c.S3ForcePathStyle = &force |
|||
return c |
|||
} |
|||
|
|||
// WithS3Disable100Continue sets a config S3Disable100Continue value returning
|
|||
// a Config pointer for chaining.
|
|||
func (c *Config) WithS3Disable100Continue(disable bool) *Config { |
|||
c.S3Disable100Continue = &disable |
|||
return c |
|||
} |
|||
|
|||
// WithS3UseAccelerate sets a config S3UseAccelerate value returning a Config
|
|||
// pointer for chaining.
|
|||
func (c *Config) WithS3UseAccelerate(enable bool) *Config { |
|||
c.S3UseAccelerate = &enable |
|||
return c |
|||
} |
|||
|
|||
// WithUseDualStack sets a config UseDualStack value returning a Config
|
|||
// pointer for chaining.
|
|||
func (c *Config) WithUseDualStack(enable bool) *Config { |
|||
c.UseDualStack = &enable |
|||
return c |
|||
} |
|||
|
|||
// WithEC2MetadataDisableTimeoutOverride sets a config EC2MetadataDisableTimeoutOverride value
|
|||
// returning a Config pointer for chaining.
|
|||
func (c *Config) WithEC2MetadataDisableTimeoutOverride(enable bool) *Config { |
|||
c.EC2MetadataDisableTimeoutOverride = &enable |
|||
return c |
|||
} |
|||
|
|||
// WithSleepDelay overrides the function used to sleep while waiting for the
|
|||
// next retry. Defaults to time.Sleep.
|
|||
func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config { |
|||
c.SleepDelay = fn |
|||
return c |
|||
} |
|||
|
|||
// MergeIn merges the passed in configs into the existing config object.
|
|||
func (c *Config) MergeIn(cfgs ...*Config) { |
|||
for _, other := range cfgs { |
|||
mergeInConfig(c, other) |
|||
} |
|||
} |
|||
|
|||
func mergeInConfig(dst *Config, other *Config) { |
|||
if other == nil { |
|||
return |
|||
} |
|||
|
|||
if other.CredentialsChainVerboseErrors != nil { |
|||
dst.CredentialsChainVerboseErrors = other.CredentialsChainVerboseErrors |
|||
} |
|||
|
|||
if other.Credentials != nil { |
|||
dst.Credentials = other.Credentials |
|||
} |
|||
|
|||
if other.Endpoint != nil { |
|||
dst.Endpoint = other.Endpoint |
|||
} |
|||
|
|||
if other.Region != nil { |
|||
dst.Region = other.Region |
|||
} |
|||
|
|||
if other.DisableSSL != nil { |
|||
dst.DisableSSL = other.DisableSSL |
|||
} |
|||
|
|||
if other.HTTPClient != nil { |
|||
dst.HTTPClient = other.HTTPClient |
|||
} |
|||
|
|||
if other.LogLevel != nil { |
|||
dst.LogLevel = other.LogLevel |
|||
} |
|||
|
|||
if other.Logger != nil { |
|||
dst.Logger = other.Logger |
|||
} |
|||
|
|||
if other.MaxRetries != nil { |
|||
dst.MaxRetries = other.MaxRetries |
|||
} |
|||
|
|||
if other.Retryer != nil { |
|||
dst.Retryer = other.Retryer |
|||
} |
|||
|
|||
if other.DisableParamValidation != nil { |
|||
dst.DisableParamValidation = other.DisableParamValidation |
|||
} |
|||
|
|||
if other.DisableComputeChecksums != nil { |
|||
dst.DisableComputeChecksums = other.DisableComputeChecksums |
|||
} |
|||
|
|||
if other.S3ForcePathStyle != nil { |
|||
dst.S3ForcePathStyle = other.S3ForcePathStyle |
|||
} |
|||
|
|||
if other.S3Disable100Continue != nil { |
|||
dst.S3Disable100Continue = other.S3Disable100Continue |
|||
} |
|||
|
|||
if other.S3UseAccelerate != nil { |
|||
dst.S3UseAccelerate = other.S3UseAccelerate |
|||
} |
|||
|
|||
if other.UseDualStack != nil { |
|||
dst.UseDualStack = other.UseDualStack |
|||
} |
|||
|
|||
if other.EC2MetadataDisableTimeoutOverride != nil { |
|||
dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride |
|||
} |
|||
|
|||
if other.SleepDelay != nil { |
|||
dst.SleepDelay = other.SleepDelay |
|||
} |
|||
} |
|||
|
|||
// Copy will return a shallow copy of the Config object. If any additional
|
|||
// configurations are provided they will be merged into the new config returned.
|
|||
func (c *Config) Copy(cfgs ...*Config) *Config { |
|||
dst := &Config{} |
|||
dst.MergeIn(c) |
|||
|
|||
for _, cfg := range cfgs { |
|||
dst.MergeIn(cfg) |
|||
} |
|||
|
|||
return dst |
|||
} |
@ -0,0 +1,369 @@ |
|||
package aws |
|||
|
|||
import "time" |
|||
|
|||
// String returns a pointer to the string value passed in.
|
|||
func String(v string) *string { |
|||
return &v |
|||
} |
|||
|
|||
// StringValue returns the value of the string pointer passed in or
|
|||
// "" if the pointer is nil.
|
|||
func StringValue(v *string) string { |
|||
if v != nil { |
|||
return *v |
|||
} |
|||
return "" |
|||
} |
|||
|
|||
// StringSlice converts a slice of string values into a slice of
|
|||
// string pointers
|
|||
func StringSlice(src []string) []*string { |
|||
dst := make([]*string, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
dst[i] = &(src[i]) |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// StringValueSlice converts a slice of string pointers into a slice of
|
|||
// string values
|
|||
func StringValueSlice(src []*string) []string { |
|||
dst := make([]string, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
if src[i] != nil { |
|||
dst[i] = *(src[i]) |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// StringMap converts a string map of string values into a string
|
|||
// map of string pointers
|
|||
func StringMap(src map[string]string) map[string]*string { |
|||
dst := make(map[string]*string) |
|||
for k, val := range src { |
|||
v := val |
|||
dst[k] = &v |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// StringValueMap converts a string map of string pointers into a string
|
|||
// map of string values
|
|||
func StringValueMap(src map[string]*string) map[string]string { |
|||
dst := make(map[string]string) |
|||
for k, val := range src { |
|||
if val != nil { |
|||
dst[k] = *val |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Bool returns a pointer to the bool value passed in.
|
|||
func Bool(v bool) *bool { |
|||
return &v |
|||
} |
|||
|
|||
// BoolValue returns the value of the bool pointer passed in or
|
|||
// false if the pointer is nil.
|
|||
func BoolValue(v *bool) bool { |
|||
if v != nil { |
|||
return *v |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// BoolSlice converts a slice of bool values into a slice of
|
|||
// bool pointers
|
|||
func BoolSlice(src []bool) []*bool { |
|||
dst := make([]*bool, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
dst[i] = &(src[i]) |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// BoolValueSlice converts a slice of bool pointers into a slice of
|
|||
// bool values
|
|||
func BoolValueSlice(src []*bool) []bool { |
|||
dst := make([]bool, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
if src[i] != nil { |
|||
dst[i] = *(src[i]) |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// BoolMap converts a string map of bool values into a string
|
|||
// map of bool pointers
|
|||
func BoolMap(src map[string]bool) map[string]*bool { |
|||
dst := make(map[string]*bool) |
|||
for k, val := range src { |
|||
v := val |
|||
dst[k] = &v |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// BoolValueMap converts a string map of bool pointers into a string
|
|||
// map of bool values
|
|||
func BoolValueMap(src map[string]*bool) map[string]bool { |
|||
dst := make(map[string]bool) |
|||
for k, val := range src { |
|||
if val != nil { |
|||
dst[k] = *val |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Int returns a pointer to the int value passed in.
|
|||
func Int(v int) *int { |
|||
return &v |
|||
} |
|||
|
|||
// IntValue returns the value of the int pointer passed in or
|
|||
// 0 if the pointer is nil.
|
|||
func IntValue(v *int) int { |
|||
if v != nil { |
|||
return *v |
|||
} |
|||
return 0 |
|||
} |
|||
|
|||
// IntSlice converts a slice of int values into a slice of
|
|||
// int pointers
|
|||
func IntSlice(src []int) []*int { |
|||
dst := make([]*int, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
dst[i] = &(src[i]) |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// IntValueSlice converts a slice of int pointers into a slice of
|
|||
// int values
|
|||
func IntValueSlice(src []*int) []int { |
|||
dst := make([]int, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
if src[i] != nil { |
|||
dst[i] = *(src[i]) |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// IntMap converts a string map of int values into a string
|
|||
// map of int pointers
|
|||
func IntMap(src map[string]int) map[string]*int { |
|||
dst := make(map[string]*int) |
|||
for k, val := range src { |
|||
v := val |
|||
dst[k] = &v |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// IntValueMap converts a string map of int pointers into a string
|
|||
// map of int values
|
|||
func IntValueMap(src map[string]*int) map[string]int { |
|||
dst := make(map[string]int) |
|||
for k, val := range src { |
|||
if val != nil { |
|||
dst[k] = *val |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Int64 returns a pointer to the int64 value passed in.
|
|||
func Int64(v int64) *int64 { |
|||
return &v |
|||
} |
|||
|
|||
// Int64Value returns the value of the int64 pointer passed in or
|
|||
// 0 if the pointer is nil.
|
|||
func Int64Value(v *int64) int64 { |
|||
if v != nil { |
|||
return *v |
|||
} |
|||
return 0 |
|||
} |
|||
|
|||
// Int64Slice converts a slice of int64 values into a slice of
|
|||
// int64 pointers
|
|||
func Int64Slice(src []int64) []*int64 { |
|||
dst := make([]*int64, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
dst[i] = &(src[i]) |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Int64ValueSlice converts a slice of int64 pointers into a slice of
|
|||
// int64 values
|
|||
func Int64ValueSlice(src []*int64) []int64 { |
|||
dst := make([]int64, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
if src[i] != nil { |
|||
dst[i] = *(src[i]) |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Int64Map converts a string map of int64 values into a string
|
|||
// map of int64 pointers
|
|||
func Int64Map(src map[string]int64) map[string]*int64 { |
|||
dst := make(map[string]*int64) |
|||
for k, val := range src { |
|||
v := val |
|||
dst[k] = &v |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Int64ValueMap converts a string map of int64 pointers into a string
|
|||
// map of int64 values
|
|||
func Int64ValueMap(src map[string]*int64) map[string]int64 { |
|||
dst := make(map[string]int64) |
|||
for k, val := range src { |
|||
if val != nil { |
|||
dst[k] = *val |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Float64 returns a pointer to the float64 value passed in.
|
|||
func Float64(v float64) *float64 { |
|||
return &v |
|||
} |
|||
|
|||
// Float64Value returns the value of the float64 pointer passed in or
|
|||
// 0 if the pointer is nil.
|
|||
func Float64Value(v *float64) float64 { |
|||
if v != nil { |
|||
return *v |
|||
} |
|||
return 0 |
|||
} |
|||
|
|||
// Float64Slice converts a slice of float64 values into a slice of
|
|||
// float64 pointers
|
|||
func Float64Slice(src []float64) []*float64 { |
|||
dst := make([]*float64, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
dst[i] = &(src[i]) |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Float64ValueSlice converts a slice of float64 pointers into a slice of
|
|||
// float64 values
|
|||
func Float64ValueSlice(src []*float64) []float64 { |
|||
dst := make([]float64, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
if src[i] != nil { |
|||
dst[i] = *(src[i]) |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Float64Map converts a string map of float64 values into a string
|
|||
// map of float64 pointers
|
|||
func Float64Map(src map[string]float64) map[string]*float64 { |
|||
dst := make(map[string]*float64) |
|||
for k, val := range src { |
|||
v := val |
|||
dst[k] = &v |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Float64ValueMap converts a string map of float64 pointers into a string
|
|||
// map of float64 values
|
|||
func Float64ValueMap(src map[string]*float64) map[string]float64 { |
|||
dst := make(map[string]float64) |
|||
for k, val := range src { |
|||
if val != nil { |
|||
dst[k] = *val |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// Time returns a pointer to the time.Time value passed in.
|
|||
func Time(v time.Time) *time.Time { |
|||
return &v |
|||
} |
|||
|
|||
// TimeValue returns the value of the time.Time pointer passed in or
|
|||
// time.Time{} if the pointer is nil.
|
|||
func TimeValue(v *time.Time) time.Time { |
|||
if v != nil { |
|||
return *v |
|||
} |
|||
return time.Time{} |
|||
} |
|||
|
|||
// TimeUnixMilli returns a Unix timestamp in milliseconds from "January 1, 1970 UTC".
|
|||
// The result is undefined if the Unix time cannot be represented by an int64.
|
|||
// Which includes calling TimeUnixMilli on a zero Time is undefined.
|
|||
//
|
|||
// This utility is useful for service API's such as CloudWatch Logs which require
|
|||
// their unix time values to be in milliseconds.
|
|||
//
|
|||
// See Go stdlib https://golang.org/pkg/time/#Time.UnixNano for more information.
|
|||
func TimeUnixMilli(t time.Time) int64 { |
|||
return t.UnixNano() / int64(time.Millisecond/time.Nanosecond) |
|||
} |
|||
|
|||
// TimeSlice converts a slice of time.Time values into a slice of
|
|||
// time.Time pointers
|
|||
func TimeSlice(src []time.Time) []*time.Time { |
|||
dst := make([]*time.Time, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
dst[i] = &(src[i]) |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// TimeValueSlice converts a slice of time.Time pointers into a slice of
|
|||
// time.Time values
|
|||
func TimeValueSlice(src []*time.Time) []time.Time { |
|||
dst := make([]time.Time, len(src)) |
|||
for i := 0; i < len(src); i++ { |
|||
if src[i] != nil { |
|||
dst[i] = *(src[i]) |
|||
} |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// TimeMap converts a string map of time.Time values into a string
|
|||
// map of time.Time pointers
|
|||
func TimeMap(src map[string]time.Time) map[string]*time.Time { |
|||
dst := make(map[string]*time.Time) |
|||
for k, val := range src { |
|||
v := val |
|||
dst[k] = &v |
|||
} |
|||
return dst |
|||
} |
|||
|
|||
// TimeValueMap converts a string map of time.Time pointers into a string
|
|||
// map of time.Time values
|
|||
func TimeValueMap(src map[string]*time.Time) map[string]time.Time { |
|||
dst := make(map[string]time.Time) |
|||
for k, val := range src { |
|||
if val != nil { |
|||
dst[k] = *val |
|||
} |
|||
} |
|||
return dst |
|||
} |
@ -0,0 +1,182 @@ |
|||
package corehandlers |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"io" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/url" |
|||
"regexp" |
|||
"runtime" |
|||
"strconv" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
// Interface for matching types which also have a Len method.
|
|||
type lener interface { |
|||
Len() int |
|||
} |
|||
|
|||
// BuildContentLengthHandler builds the content length of a request based on the body,
|
|||
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
|
|||
// to determine request body length and no "Content-Length" was specified it will panic.
|
|||
//
|
|||
// The Content-Length will only be aded to the request if the length of the body
|
|||
// is greater than 0. If the body is empty or the current `Content-Length`
|
|||
// header is <= 0, the header will also be stripped.
|
|||
var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) { |
|||
var length int64 |
|||
|
|||
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" { |
|||
length, _ = strconv.ParseInt(slength, 10, 64) |
|||
} else { |
|||
switch body := r.Body.(type) { |
|||
case nil: |
|||
length = 0 |
|||
case lener: |
|||
length = int64(body.Len()) |
|||
case io.Seeker: |
|||
r.BodyStart, _ = body.Seek(0, 1) |
|||
end, _ := body.Seek(0, 2) |
|||
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
|
|||
length = end - r.BodyStart |
|||
default: |
|||
panic("Cannot get length of body, must provide `ContentLength`") |
|||
} |
|||
} |
|||
|
|||
if length > 0 { |
|||
r.HTTPRequest.ContentLength = length |
|||
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length)) |
|||
} else { |
|||
r.HTTPRequest.ContentLength = 0 |
|||
r.HTTPRequest.Header.Del("Content-Length") |
|||
} |
|||
}} |
|||
|
|||
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent.
|
|||
var SDKVersionUserAgentHandler = request.NamedHandler{ |
|||
Name: "core.SDKVersionUserAgentHandler", |
|||
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion, |
|||
runtime.Version(), runtime.GOOS, runtime.GOARCH), |
|||
} |
|||
|
|||
var reStatusCode = regexp.MustCompile(`^(\d{3})`) |
|||
|
|||
// ValidateReqSigHandler is a request handler to ensure that the request's
|
|||
// signature doesn't expire before it is sent. This can happen when a request
|
|||
// is built and signed signficantly before it is sent. Or signficant delays
|
|||
// occur whne retrying requests that would cause the signature to expire.
|
|||
var ValidateReqSigHandler = request.NamedHandler{ |
|||
Name: "core.ValidateReqSigHandler", |
|||
Fn: func(r *request.Request) { |
|||
// Unsigned requests are not signed
|
|||
if r.Config.Credentials == credentials.AnonymousCredentials { |
|||
return |
|||
} |
|||
|
|||
signedTime := r.Time |
|||
if !r.LastSignedAt.IsZero() { |
|||
signedTime = r.LastSignedAt |
|||
} |
|||
|
|||
// 10 minutes to allow for some clock skew/delays in transmission.
|
|||
// Would be improved with aws/aws-sdk-go#423
|
|||
if signedTime.Add(10 * time.Minute).After(time.Now()) { |
|||
return |
|||
} |
|||
|
|||
fmt.Println("request expired, resigning") |
|||
r.Sign() |
|||
}, |
|||
} |
|||
|
|||
// SendHandler is a request handler to send service request using HTTP client.
|
|||
var SendHandler = request.NamedHandler{Name: "core.SendHandler", Fn: func(r *request.Request) { |
|||
var err error |
|||
r.HTTPResponse, err = r.Config.HTTPClient.Do(r.HTTPRequest) |
|||
if err != nil { |
|||
// Prevent leaking if an HTTPResponse was returned. Clean up
|
|||
// the body.
|
|||
if r.HTTPResponse != nil { |
|||
r.HTTPResponse.Body.Close() |
|||
} |
|||
// Capture the case where url.Error is returned for error processing
|
|||
// response. e.g. 301 without location header comes back as string
|
|||
// error and r.HTTPResponse is nil. Other url redirect errors will
|
|||
// comeback in a similar method.
|
|||
if e, ok := err.(*url.Error); ok && e.Err != nil { |
|||
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil { |
|||
code, _ := strconv.ParseInt(s[1], 10, 64) |
|||
r.HTTPResponse = &http.Response{ |
|||
StatusCode: int(code), |
|||
Status: http.StatusText(int(code)), |
|||
Body: ioutil.NopCloser(bytes.NewReader([]byte{})), |
|||
} |
|||
return |
|||
} |
|||
} |
|||
if r.HTTPResponse == nil { |
|||
// Add a dummy request response object to ensure the HTTPResponse
|
|||
// value is consistent.
|
|||
r.HTTPResponse = &http.Response{ |
|||
StatusCode: int(0), |
|||
Status: http.StatusText(int(0)), |
|||
Body: ioutil.NopCloser(bytes.NewReader([]byte{})), |
|||
} |
|||
} |
|||
// Catch all other request errors.
|
|||
r.Error = awserr.New("RequestError", "send request failed", err) |
|||
r.Retryable = aws.Bool(true) // network errors are retryable
|
|||
} |
|||
}} |
|||
|
|||
// ValidateResponseHandler is a request handler to validate service response.
|
|||
var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) { |
|||
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 { |
|||
// this may be replaced by an UnmarshalError handler
|
|||
r.Error = awserr.New("UnknownError", "unknown error", nil) |
|||
} |
|||
}} |
|||
|
|||
// AfterRetryHandler performs final checks to determine if the request should
|
|||
// be retried and how long to delay.
|
|||
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) { |
|||
// If one of the other handlers already set the retry state
|
|||
// we don't want to override it based on the service's state
|
|||
if r.Retryable == nil { |
|||
r.Retryable = aws.Bool(r.ShouldRetry(r)) |
|||
} |
|||
|
|||
if r.WillRetry() { |
|||
r.RetryDelay = r.RetryRules(r) |
|||
r.Config.SleepDelay(r.RetryDelay) |
|||
|
|||
// when the expired token exception occurs the credentials
|
|||
// need to be expired locally so that the next request to
|
|||
// get credentials will trigger a credentials refresh.
|
|||
if r.IsErrorExpired() { |
|||
r.Config.Credentials.Expire() |
|||
} |
|||
|
|||
r.RetryCount++ |
|||
r.Error = nil |
|||
} |
|||
}} |
|||
|
|||
// ValidateEndpointHandler is a request handler to validate a request had the
|
|||
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
|
|||
// region is not valid.
|
|||
var ValidateEndpointHandler = request.NamedHandler{Name: "core.ValidateEndpointHandler", Fn: func(r *request.Request) { |
|||
if r.ClientInfo.SigningRegion == "" && aws.StringValue(r.Config.Region) == "" { |
|||
r.Error = aws.ErrMissingRegion |
|||
} else if r.ClientInfo.Endpoint == "" { |
|||
r.Error = aws.ErrMissingEndpoint |
|||
} |
|||
}} |
@ -0,0 +1,17 @@ |
|||
package corehandlers |
|||
|
|||
import "github.com/aws/aws-sdk-go/aws/request" |
|||
|
|||
// ValidateParametersHandler is a request handler to validate the input parameters.
|
|||
// Validating parameters only has meaning if done prior to the request being sent.
|
|||
var ValidateParametersHandler = request.NamedHandler{Name: "core.ValidateParametersHandler", Fn: func(r *request.Request) { |
|||
if !r.ParamsFilled() { |
|||
return |
|||
} |
|||
|
|||
if v, ok := r.Params.(request.Validator); ok { |
|||
if err := v.Validate(); err != nil { |
|||
r.Error = err |
|||
} |
|||
} |
|||
}} |
@ -0,0 +1,100 @@ |
|||
package credentials |
|||
|
|||
import ( |
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
) |
|||
|
|||
var ( |
|||
// ErrNoValidProvidersFoundInChain Is returned when there are no valid
|
|||
// providers in the ChainProvider.
|
|||
//
|
|||
// This has been deprecated. For verbose error messaging set
|
|||
// aws.Config.CredentialsChainVerboseErrors to true
|
|||
//
|
|||
// @readonly
|
|||
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders", |
|||
`no valid providers in chain. Deprecated. |
|||
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`, |
|||
nil) |
|||
) |
|||
|
|||
// A ChainProvider will search for a provider which returns credentials
|
|||
// and cache that provider until Retrieve is called again.
|
|||
//
|
|||
// The ChainProvider provides a way of chaining multiple providers together
|
|||
// which will pick the first available using priority order of the Providers
|
|||
// in the list.
|
|||
//
|
|||
// If none of the Providers retrieve valid credentials Value, ChainProvider's
|
|||
// Retrieve() will return the error ErrNoValidProvidersFoundInChain.
|
|||
//
|
|||
// If a Provider is found which returns valid credentials Value ChainProvider
|
|||
// will cache that Provider for all calls to IsExpired(), until Retrieve is
|
|||
// called again.
|
|||
//
|
|||
// Example of ChainProvider to be used with an EnvProvider and EC2RoleProvider.
|
|||
// In this example EnvProvider will first check if any credentials are available
|
|||
// vai the environment variables. If there are none ChainProvider will check
|
|||
// the next Provider in the list, EC2RoleProvider in this case. If EC2RoleProvider
|
|||
// does not return any credentials ChainProvider will return the error
|
|||
// ErrNoValidProvidersFoundInChain
|
|||
//
|
|||
// creds := NewChainCredentials(
|
|||
// []Provider{
|
|||
// &EnvProvider{},
|
|||
// &EC2RoleProvider{
|
|||
// Client: ec2metadata.New(sess),
|
|||
// },
|
|||
// })
|
|||
//
|
|||
// // Usage of ChainCredentials with aws.Config
|
|||
// svc := ec2.New(&aws.Config{Credentials: creds})
|
|||
//
|
|||
type ChainProvider struct { |
|||
Providers []Provider |
|||
curr Provider |
|||
VerboseErrors bool |
|||
} |
|||
|
|||
// NewChainCredentials returns a pointer to a new Credentials object
|
|||
// wrapping a chain of providers.
|
|||
func NewChainCredentials(providers []Provider) *Credentials { |
|||
return NewCredentials(&ChainProvider{ |
|||
Providers: append([]Provider{}, providers...), |
|||
}) |
|||
} |
|||
|
|||
// Retrieve returns the credentials value or error if no provider returned
|
|||
// without error.
|
|||
//
|
|||
// If a provider is found it will be cached and any calls to IsExpired()
|
|||
// will return the expired state of the cached provider.
|
|||
func (c *ChainProvider) Retrieve() (Value, error) { |
|||
var errs []error |
|||
for _, p := range c.Providers { |
|||
creds, err := p.Retrieve() |
|||
if err == nil { |
|||
c.curr = p |
|||
return creds, nil |
|||
} |
|||
errs = append(errs, err) |
|||
} |
|||
c.curr = nil |
|||
|
|||
var err error |
|||
err = ErrNoValidProvidersFoundInChain |
|||
if c.VerboseErrors { |
|||
err = awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs) |
|||
} |
|||
return Value{}, err |
|||
} |
|||
|
|||
// IsExpired will returned the expired state of the currently cached provider
|
|||
// if there is one. If there is no current provider, true will be returned.
|
|||
func (c *ChainProvider) IsExpired() bool { |
|||
if c.curr != nil { |
|||
return c.curr.IsExpired() |
|||
} |
|||
|
|||
return true |
|||
} |
@ -0,0 +1,223 @@ |
|||
// Package credentials provides credential retrieval and management
|
|||
//
|
|||
// The Credentials is the primary method of getting access to and managing
|
|||
// credentials Values. Using dependency injection retrieval of the credential
|
|||
// values is handled by a object which satisfies the Provider interface.
|
|||
//
|
|||
// By default the Credentials.Get() will cache the successful result of a
|
|||
// Provider's Retrieve() until Provider.IsExpired() returns true. At which
|
|||
// point Credentials will call Provider's Retrieve() to get new credential Value.
|
|||
//
|
|||
// The Provider is responsible for determining when credentials Value have expired.
|
|||
// It is also important to note that Credentials will always call Retrieve the
|
|||
// first time Credentials.Get() is called.
|
|||
//
|
|||
// Example of using the environment variable credentials.
|
|||
//
|
|||
// creds := NewEnvCredentials()
|
|||
//
|
|||
// // Retrieve the credentials value
|
|||
// credValue, err := creds.Get()
|
|||
// if err != nil {
|
|||
// // handle error
|
|||
// }
|
|||
//
|
|||
// Example of forcing credentials to expire and be refreshed on the next Get().
|
|||
// This may be helpful to proactively expire credentials and refresh them sooner
|
|||
// than they would naturally expire on their own.
|
|||
//
|
|||
// creds := NewCredentials(&EC2RoleProvider{})
|
|||
// creds.Expire()
|
|||
// credsValue, err := creds.Get()
|
|||
// // New credentials will be retrieved instead of from cache.
|
|||
//
|
|||
//
|
|||
// Custom Provider
|
|||
//
|
|||
// Each Provider built into this package also provides a helper method to generate
|
|||
// a Credentials pointer setup with the provider. To use a custom Provider just
|
|||
// create a type which satisfies the Provider interface and pass it to the
|
|||
// NewCredentials method.
|
|||
//
|
|||
// type MyProvider struct{}
|
|||
// func (m *MyProvider) Retrieve() (Value, error) {...}
|
|||
// func (m *MyProvider) IsExpired() bool {...}
|
|||
//
|
|||
// creds := NewCredentials(&MyProvider{})
|
|||
// credValue, err := creds.Get()
|
|||
//
|
|||
package credentials |
|||
|
|||
import ( |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
// AnonymousCredentials is an empty Credential object that can be used as
|
|||
// dummy placeholder credentials for requests that do not need signed.
|
|||
//
|
|||
// This Credentials can be used to configure a service to not sign requests
|
|||
// when making service API calls. For example, when accessing public
|
|||
// s3 buckets.
|
|||
//
|
|||
// svc := s3.New(&aws.Config{Credentials: AnonymousCredentials})
|
|||
// // Access public S3 buckets.
|
|||
//
|
|||
// @readonly
|
|||
var AnonymousCredentials = NewStaticCredentials("", "", "") |
|||
|
|||
// A Value is the AWS credentials value for individual credential fields.
|
|||
type Value struct { |
|||
// AWS Access key ID
|
|||
AccessKeyID string |
|||
|
|||
// AWS Secret Access Key
|
|||
SecretAccessKey string |
|||
|
|||
// AWS Session Token
|
|||
SessionToken string |
|||
|
|||
// Provider used to get credentials
|
|||
ProviderName string |
|||
} |
|||
|
|||
// A Provider is the interface for any component which will provide credentials
|
|||
// Value. A provider is required to manage its own Expired state, and what to
|
|||
// be expired means.
|
|||
//
|
|||
// The Provider should not need to implement its own mutexes, because
|
|||
// that will be managed by Credentials.
|
|||
type Provider interface { |
|||
// Refresh returns nil if it successfully retrieved the value.
|
|||
// Error is returned if the value were not obtainable, or empty.
|
|||
Retrieve() (Value, error) |
|||
|
|||
// IsExpired returns if the credentials are no longer valid, and need
|
|||
// to be retrieved.
|
|||
IsExpired() bool |
|||
} |
|||
|
|||
// A Expiry provides shared expiration logic to be used by credentials
|
|||
// providers to implement expiry functionality.
|
|||
//
|
|||
// The best method to use this struct is as an anonymous field within the
|
|||
// provider's struct.
|
|||
//
|
|||
// Example:
|
|||
// type EC2RoleProvider struct {
|
|||
// Expiry
|
|||
// ...
|
|||
// }
|
|||
type Expiry struct { |
|||
// The date/time when to expire on
|
|||
expiration time.Time |
|||
|
|||
// If set will be used by IsExpired to determine the current time.
|
|||
// Defaults to time.Now if CurrentTime is not set. Available for testing
|
|||
// to be able to mock out the current time.
|
|||
CurrentTime func() time.Time |
|||
} |
|||
|
|||
// SetExpiration sets the expiration IsExpired will check when called.
|
|||
//
|
|||
// If window is greater than 0 the expiration time will be reduced by the
|
|||
// window value.
|
|||
//
|
|||
// Using a window is helpful to trigger credentials to expire sooner than
|
|||
// the expiration time given to ensure no requests are made with expired
|
|||
// tokens.
|
|||
func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) { |
|||
e.expiration = expiration |
|||
if window > 0 { |
|||
e.expiration = e.expiration.Add(-window) |
|||
} |
|||
} |
|||
|
|||
// IsExpired returns if the credentials are expired.
|
|||
func (e *Expiry) IsExpired() bool { |
|||
if e.CurrentTime == nil { |
|||
e.CurrentTime = time.Now |
|||
} |
|||
return e.expiration.Before(e.CurrentTime()) |
|||
} |
|||
|
|||
// A Credentials provides synchronous safe retrieval of AWS credentials Value.
|
|||
// Credentials will cache the credentials value until they expire. Once the value
|
|||
// expires the next Get will attempt to retrieve valid credentials.
|
|||
//
|
|||
// Credentials is safe to use across multiple goroutines and will manage the
|
|||
// synchronous state so the Providers do not need to implement their own
|
|||
// synchronization.
|
|||
//
|
|||
// The first Credentials.Get() will always call Provider.Retrieve() to get the
|
|||
// first instance of the credentials Value. All calls to Get() after that
|
|||
// will return the cached credentials Value until IsExpired() returns true.
|
|||
type Credentials struct { |
|||
creds Value |
|||
forceRefresh bool |
|||
m sync.Mutex |
|||
|
|||
provider Provider |
|||
} |
|||
|
|||
// NewCredentials returns a pointer to a new Credentials with the provider set.
|
|||
func NewCredentials(provider Provider) *Credentials { |
|||
return &Credentials{ |
|||
provider: provider, |
|||
forceRefresh: true, |
|||
} |
|||
} |
|||
|
|||
// Get returns the credentials value, or error if the credentials Value failed
|
|||
// to be retrieved.
|
|||
//
|
|||
// Will return the cached credentials Value if it has not expired. If the
|
|||
// credentials Value has expired the Provider's Retrieve() will be called
|
|||
// to refresh the credentials.
|
|||
//
|
|||
// If Credentials.Expire() was called the credentials Value will be force
|
|||
// expired, and the next call to Get() will cause them to be refreshed.
|
|||
func (c *Credentials) Get() (Value, error) { |
|||
c.m.Lock() |
|||
defer c.m.Unlock() |
|||
|
|||
if c.isExpired() { |
|||
creds, err := c.provider.Retrieve() |
|||
if err != nil { |
|||
return Value{}, err |
|||
} |
|||
c.creds = creds |
|||
c.forceRefresh = false |
|||
} |
|||
|
|||
return c.creds, nil |
|||
} |
|||
|
|||
// Expire expires the credentials and forces them to be retrieved on the
|
|||
// next call to Get().
|
|||
//
|
|||
// This will override the Provider's expired state, and force Credentials
|
|||
// to call the Provider's Retrieve().
|
|||
func (c *Credentials) Expire() { |
|||
c.m.Lock() |
|||
defer c.m.Unlock() |
|||
|
|||
c.forceRefresh = true |
|||
} |
|||
|
|||
// IsExpired returns if the credentials are no longer valid, and need
|
|||
// to be retrieved.
|
|||
//
|
|||
// If the Credentials were forced to be expired with Expire() this will
|
|||
// reflect that override.
|
|||
func (c *Credentials) IsExpired() bool { |
|||
c.m.Lock() |
|||
defer c.m.Unlock() |
|||
|
|||
return c.isExpired() |
|||
} |
|||
|
|||
// isExpired helper method wrapping the definition of expired credentials.
|
|||
func (c *Credentials) isExpired() bool { |
|||
return c.forceRefresh || c.provider.IsExpired() |
|||
} |
@ -0,0 +1,178 @@ |
|||
package ec2rolecreds |
|||
|
|||
import ( |
|||
"bufio" |
|||
"encoding/json" |
|||
"fmt" |
|||
"path" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/client" |
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
"github.com/aws/aws-sdk-go/aws/ec2metadata" |
|||
) |
|||
|
|||
// ProviderName provides a name of EC2Role provider
|
|||
const ProviderName = "EC2RoleProvider" |
|||
|
|||
// A EC2RoleProvider retrieves credentials from the EC2 service, and keeps track if
|
|||
// those credentials are expired.
|
|||
//
|
|||
// Example how to configure the EC2RoleProvider with custom http Client, Endpoint
|
|||
// or ExpiryWindow
|
|||
//
|
|||
// p := &ec2rolecreds.EC2RoleProvider{
|
|||
// // Pass in a custom timeout to be used when requesting
|
|||
// // IAM EC2 Role credentials.
|
|||
// Client: ec2metadata.New(sess, aws.Config{
|
|||
// HTTPClient: &http.Client{Timeout: 10 * time.Second},
|
|||
// }),
|
|||
//
|
|||
// // Do not use early expiry of credentials. If a non zero value is
|
|||
// // specified the credentials will be expired early
|
|||
// ExpiryWindow: 0,
|
|||
// }
|
|||
type EC2RoleProvider struct { |
|||
credentials.Expiry |
|||
|
|||
// Required EC2Metadata client to use when connecting to EC2 metadata service.
|
|||
Client *ec2metadata.EC2Metadata |
|||
|
|||
// ExpiryWindow will allow the credentials to trigger refreshing prior to
|
|||
// the credentials actually expiring. This is beneficial so race conditions
|
|||
// with expiring credentials do not cause request to fail unexpectedly
|
|||
// due to ExpiredTokenException exceptions.
|
|||
//
|
|||
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
|||
// 10 seconds before the credentials are actually expired.
|
|||
//
|
|||
// If ExpiryWindow is 0 or less it will be ignored.
|
|||
ExpiryWindow time.Duration |
|||
} |
|||
|
|||
// NewCredentials returns a pointer to a new Credentials object wrapping
|
|||
// the EC2RoleProvider. Takes a ConfigProvider to create a EC2Metadata client.
|
|||
// The ConfigProvider is satisfied by the session.Session type.
|
|||
func NewCredentials(c client.ConfigProvider, options ...func(*EC2RoleProvider)) *credentials.Credentials { |
|||
p := &EC2RoleProvider{ |
|||
Client: ec2metadata.New(c), |
|||
} |
|||
|
|||
for _, option := range options { |
|||
option(p) |
|||
} |
|||
|
|||
return credentials.NewCredentials(p) |
|||
} |
|||
|
|||
// NewCredentialsWithClient returns a pointer to a new Credentials object wrapping
|
|||
// the EC2RoleProvider. Takes a EC2Metadata client to use when connecting to EC2
|
|||
// metadata service.
|
|||
func NewCredentialsWithClient(client *ec2metadata.EC2Metadata, options ...func(*EC2RoleProvider)) *credentials.Credentials { |
|||
p := &EC2RoleProvider{ |
|||
Client: client, |
|||
} |
|||
|
|||
for _, option := range options { |
|||
option(p) |
|||
} |
|||
|
|||
return credentials.NewCredentials(p) |
|||
} |
|||
|
|||
// Retrieve retrieves credentials from the EC2 service.
|
|||
// Error will be returned if the request fails, or unable to extract
|
|||
// the desired credentials.
|
|||
func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) { |
|||
credsList, err := requestCredList(m.Client) |
|||
if err != nil { |
|||
return credentials.Value{ProviderName: ProviderName}, err |
|||
} |
|||
|
|||
if len(credsList) == 0 { |
|||
return credentials.Value{ProviderName: ProviderName}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil) |
|||
} |
|||
credsName := credsList[0] |
|||
|
|||
roleCreds, err := requestCred(m.Client, credsName) |
|||
if err != nil { |
|||
return credentials.Value{ProviderName: ProviderName}, err |
|||
} |
|||
|
|||
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow) |
|||
|
|||
return credentials.Value{ |
|||
AccessKeyID: roleCreds.AccessKeyID, |
|||
SecretAccessKey: roleCreds.SecretAccessKey, |
|||
SessionToken: roleCreds.Token, |
|||
ProviderName: ProviderName, |
|||
}, nil |
|||
} |
|||
|
|||
// A ec2RoleCredRespBody provides the shape for unmarshalling credential
|
|||
// request responses.
|
|||
type ec2RoleCredRespBody struct { |
|||
// Success State
|
|||
Expiration time.Time |
|||
AccessKeyID string |
|||
SecretAccessKey string |
|||
Token string |
|||
|
|||
// Error state
|
|||
Code string |
|||
Message string |
|||
} |
|||
|
|||
const iamSecurityCredsPath = "/iam/security-credentials" |
|||
|
|||
// requestCredList requests a list of credentials from the EC2 service.
|
|||
// If there are no credentials, or there is an error making or receiving the request
|
|||
func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) { |
|||
resp, err := client.GetMetadata(iamSecurityCredsPath) |
|||
if err != nil { |
|||
return nil, awserr.New("EC2RoleRequestError", "no EC2 instance role found", err) |
|||
} |
|||
|
|||
credsList := []string{} |
|||
s := bufio.NewScanner(strings.NewReader(resp)) |
|||
for s.Scan() { |
|||
credsList = append(credsList, s.Text()) |
|||
} |
|||
|
|||
if err := s.Err(); err != nil { |
|||
return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err) |
|||
} |
|||
|
|||
return credsList, nil |
|||
} |
|||
|
|||
// requestCred requests the credentials for a specific credentials from the EC2 service.
|
|||
//
|
|||
// If the credentials cannot be found, or there is an error reading the response
|
|||
// and error will be returned.
|
|||
func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) { |
|||
resp, err := client.GetMetadata(path.Join(iamSecurityCredsPath, credsName)) |
|||
if err != nil { |
|||
return ec2RoleCredRespBody{}, |
|||
awserr.New("EC2RoleRequestError", |
|||
fmt.Sprintf("failed to get %s EC2 instance role credentials", credsName), |
|||
err) |
|||
} |
|||
|
|||
respCreds := ec2RoleCredRespBody{} |
|||
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil { |
|||
return ec2RoleCredRespBody{}, |
|||
awserr.New("SerializationError", |
|||
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName), |
|||
err) |
|||
} |
|||
|
|||
if respCreds.Code != "Success" { |
|||
// If an error code was returned something failed requesting the role.
|
|||
return ec2RoleCredRespBody{}, awserr.New(respCreds.Code, respCreds.Message, nil) |
|||
} |
|||
|
|||
return respCreds, nil |
|||
} |
@ -0,0 +1,191 @@ |
|||
// Package endpointcreds provides support for retrieving credentials from an
|
|||
// arbitrary HTTP endpoint.
|
|||
//
|
|||
// The credentials endpoint Provider can receive both static and refreshable
|
|||
// credentials that will expire. Credentials are static when an "Expiration"
|
|||
// value is not provided in the endpoint's response.
|
|||
//
|
|||
// Static credentials will never expire once they have been retrieved. The format
|
|||
// of the static credentials response:
|
|||
// {
|
|||
// "AccessKeyId" : "MUA...",
|
|||
// "SecretAccessKey" : "/7PC5om....",
|
|||
// }
|
|||
//
|
|||
// Refreshable credentials will expire within the "ExpiryWindow" of the Expiration
|
|||
// value in the response. The format of the refreshable credentials response:
|
|||
// {
|
|||
// "AccessKeyId" : "MUA...",
|
|||
// "SecretAccessKey" : "/7PC5om....",
|
|||
// "Token" : "AQoDY....=",
|
|||
// "Expiration" : "2016-02-25T06:03:31Z"
|
|||
// }
|
|||
//
|
|||
// Errors should be returned in the following format and only returned with 400
|
|||
// or 500 HTTP status codes.
|
|||
// {
|
|||
// "code": "ErrorCode",
|
|||
// "message": "Helpful error message."
|
|||
// }
|
|||
package endpointcreds |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/client" |
|||
"github.com/aws/aws-sdk-go/aws/client/metadata" |
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
// ProviderName is the name of the credentials provider.
|
|||
const ProviderName = `CredentialsEndpointProvider` |
|||
|
|||
// Provider satisfies the credentials.Provider interface, and is a client to
|
|||
// retrieve credentials from an arbitrary endpoint.
|
|||
type Provider struct { |
|||
staticCreds bool |
|||
credentials.Expiry |
|||
|
|||
// Requires a AWS Client to make HTTP requests to the endpoint with.
|
|||
// the Endpoint the request will be made to is provided by the aws.Config's
|
|||
// Endpoint value.
|
|||
Client *client.Client |
|||
|
|||
// ExpiryWindow will allow the credentials to trigger refreshing prior to
|
|||
// the credentials actually expiring. This is beneficial so race conditions
|
|||
// with expiring credentials do not cause request to fail unexpectedly
|
|||
// due to ExpiredTokenException exceptions.
|
|||
//
|
|||
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
|||
// 10 seconds before the credentials are actually expired.
|
|||
//
|
|||
// If ExpiryWindow is 0 or less it will be ignored.
|
|||
ExpiryWindow time.Duration |
|||
} |
|||
|
|||
// NewProviderClient returns a credentials Provider for retrieving AWS credentials
|
|||
// from arbitrary endpoint.
|
|||
func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) credentials.Provider { |
|||
p := &Provider{ |
|||
Client: client.New( |
|||
cfg, |
|||
metadata.ClientInfo{ |
|||
ServiceName: "CredentialsEndpoint", |
|||
Endpoint: endpoint, |
|||
}, |
|||
handlers, |
|||
), |
|||
} |
|||
|
|||
p.Client.Handlers.Unmarshal.PushBack(unmarshalHandler) |
|||
p.Client.Handlers.UnmarshalError.PushBack(unmarshalError) |
|||
p.Client.Handlers.Validate.Clear() |
|||
p.Client.Handlers.Validate.PushBack(validateEndpointHandler) |
|||
|
|||
for _, option := range options { |
|||
option(p) |
|||
} |
|||
|
|||
return p |
|||
} |
|||
|
|||
// NewCredentialsClient returns a Credentials wrapper for retrieving credentials
|
|||
// from an arbitrary endpoint concurrently. The client will request the
|
|||
func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials { |
|||
return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...)) |
|||
} |
|||
|
|||
// IsExpired returns true if the credentials retrieved are expired, or not yet
|
|||
// retrieved.
|
|||
func (p *Provider) IsExpired() bool { |
|||
if p.staticCreds { |
|||
return false |
|||
} |
|||
return p.Expiry.IsExpired() |
|||
} |
|||
|
|||
// Retrieve will attempt to request the credentials from the endpoint the Provider
|
|||
// was configured for. And error will be returned if the retrieval fails.
|
|||
func (p *Provider) Retrieve() (credentials.Value, error) { |
|||
resp, err := p.getCredentials() |
|||
if err != nil { |
|||
return credentials.Value{ProviderName: ProviderName}, |
|||
awserr.New("CredentialsEndpointError", "failed to load credentials", err) |
|||
} |
|||
|
|||
if resp.Expiration != nil { |
|||
p.SetExpiration(*resp.Expiration, p.ExpiryWindow) |
|||
} else { |
|||
p.staticCreds = true |
|||
} |
|||
|
|||
return credentials.Value{ |
|||
AccessKeyID: resp.AccessKeyID, |
|||
SecretAccessKey: resp.SecretAccessKey, |
|||
SessionToken: resp.Token, |
|||
ProviderName: ProviderName, |
|||
}, nil |
|||
} |
|||
|
|||
type getCredentialsOutput struct { |
|||
Expiration *time.Time |
|||
AccessKeyID string |
|||
SecretAccessKey string |
|||
Token string |
|||
} |
|||
|
|||
type errorOutput struct { |
|||
Code string `json:"code"` |
|||
Message string `json:"message"` |
|||
} |
|||
|
|||
func (p *Provider) getCredentials() (*getCredentialsOutput, error) { |
|||
op := &request.Operation{ |
|||
Name: "GetCredentials", |
|||
HTTPMethod: "GET", |
|||
} |
|||
|
|||
out := &getCredentialsOutput{} |
|||
req := p.Client.NewRequest(op, nil, out) |
|||
req.HTTPRequest.Header.Set("Accept", "application/json") |
|||
|
|||
return out, req.Send() |
|||
} |
|||
|
|||
func validateEndpointHandler(r *request.Request) { |
|||
if len(r.ClientInfo.Endpoint) == 0 { |
|||
r.Error = aws.ErrMissingEndpoint |
|||
} |
|||
} |
|||
|
|||
func unmarshalHandler(r *request.Request) { |
|||
defer r.HTTPResponse.Body.Close() |
|||
|
|||
out := r.Data.(*getCredentialsOutput) |
|||
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil { |
|||
r.Error = awserr.New("SerializationError", |
|||
"failed to decode endpoint credentials", |
|||
err, |
|||
) |
|||
} |
|||
} |
|||
|
|||
func unmarshalError(r *request.Request) { |
|||
defer r.HTTPResponse.Body.Close() |
|||
|
|||
var errOut errorOutput |
|||
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil { |
|||
r.Error = awserr.New("SerializationError", |
|||
"failed to decode endpoint credentials", |
|||
err, |
|||
) |
|||
} |
|||
|
|||
// Response body format is not consistent between metadata endpoints.
|
|||
// Grab the error message as a string and include that as the source error
|
|||
r.Error = awserr.New(errOut.Code, errOut.Message, nil) |
|||
} |
@ -0,0 +1,77 @@ |
|||
package credentials |
|||
|
|||
import ( |
|||
"os" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
) |
|||
|
|||
// EnvProviderName provides a name of Env provider
|
|||
const EnvProviderName = "EnvProvider" |
|||
|
|||
var ( |
|||
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
|
|||
// found in the process's environment.
|
|||
//
|
|||
// @readonly
|
|||
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil) |
|||
|
|||
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
|
|||
// can't be found in the process's environment.
|
|||
//
|
|||
// @readonly
|
|||
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil) |
|||
) |
|||
|
|||
// A EnvProvider retrieves credentials from the environment variables of the
|
|||
// running process. Environment credentials never expire.
|
|||
//
|
|||
// Environment variables used:
|
|||
//
|
|||
// * Access Key ID: AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY
|
|||
// * Secret Access Key: AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY
|
|||
type EnvProvider struct { |
|||
retrieved bool |
|||
} |
|||
|
|||
// NewEnvCredentials returns a pointer to a new Credentials object
|
|||
// wrapping the environment variable provider.
|
|||
func NewEnvCredentials() *Credentials { |
|||
return NewCredentials(&EnvProvider{}) |
|||
} |
|||
|
|||
// Retrieve retrieves the keys from the environment.
|
|||
func (e *EnvProvider) Retrieve() (Value, error) { |
|||
e.retrieved = false |
|||
|
|||
id := os.Getenv("AWS_ACCESS_KEY_ID") |
|||
if id == "" { |
|||
id = os.Getenv("AWS_ACCESS_KEY") |
|||
} |
|||
|
|||
secret := os.Getenv("AWS_SECRET_ACCESS_KEY") |
|||
if secret == "" { |
|||
secret = os.Getenv("AWS_SECRET_KEY") |
|||
} |
|||
|
|||
if id == "" { |
|||
return Value{ProviderName: EnvProviderName}, ErrAccessKeyIDNotFound |
|||
} |
|||
|
|||
if secret == "" { |
|||
return Value{ProviderName: EnvProviderName}, ErrSecretAccessKeyNotFound |
|||
} |
|||
|
|||
e.retrieved = true |
|||
return Value{ |
|||
AccessKeyID: id, |
|||
SecretAccessKey: secret, |
|||
SessionToken: os.Getenv("AWS_SESSION_TOKEN"), |
|||
ProviderName: EnvProviderName, |
|||
}, nil |
|||
} |
|||
|
|||
// IsExpired returns if the credentials have been retrieved.
|
|||
func (e *EnvProvider) IsExpired() bool { |
|||
return !e.retrieved |
|||
} |
@ -0,0 +1,12 @@ |
|||
[default] |
|||
aws_access_key_id = accessKey |
|||
aws_secret_access_key = secret |
|||
aws_session_token = token |
|||
|
|||
[no_token] |
|||
aws_access_key_id = accessKey |
|||
aws_secret_access_key = secret |
|||
|
|||
[with_colon] |
|||
aws_access_key_id: accessKey |
|||
aws_secret_access_key: secret |
@ -0,0 +1,151 @@ |
|||
package credentials |
|||
|
|||
import ( |
|||
"fmt" |
|||
"os" |
|||
"path/filepath" |
|||
|
|||
"github.com/go-ini/ini" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
) |
|||
|
|||
// SharedCredsProviderName provides a name of SharedCreds provider
|
|||
const SharedCredsProviderName = "SharedCredentialsProvider" |
|||
|
|||
var ( |
|||
// ErrSharedCredentialsHomeNotFound is emitted when the user directory cannot be found.
|
|||
//
|
|||
// @readonly
|
|||
ErrSharedCredentialsHomeNotFound = awserr.New("UserHomeNotFound", "user home directory not found.", nil) |
|||
) |
|||
|
|||
// A SharedCredentialsProvider retrieves credentials from the current user's home
|
|||
// directory, and keeps track if those credentials are expired.
|
|||
//
|
|||
// Profile ini file example: $HOME/.aws/credentials
|
|||
type SharedCredentialsProvider struct { |
|||
// Path to the shared credentials file.
|
|||
//
|
|||
// If empty will look for "AWS_SHARED_CREDENTIALS_FILE" env variable. If the
|
|||
// env value is empty will default to current user's home directory.
|
|||
// Linux/OSX: "$HOME/.aws/credentials"
|
|||
// Windows: "%USERPROFILE%\.aws\credentials"
|
|||
Filename string |
|||
|
|||
// AWS Profile to extract credentials from the shared credentials file. If empty
|
|||
// will default to environment variable "AWS_PROFILE" or "default" if
|
|||
// environment variable is also not set.
|
|||
Profile string |
|||
|
|||
// retrieved states if the credentials have been successfully retrieved.
|
|||
retrieved bool |
|||
} |
|||
|
|||
// NewSharedCredentials returns a pointer to a new Credentials object
|
|||
// wrapping the Profile file provider.
|
|||
func NewSharedCredentials(filename, profile string) *Credentials { |
|||
return NewCredentials(&SharedCredentialsProvider{ |
|||
Filename: filename, |
|||
Profile: profile, |
|||
}) |
|||
} |
|||
|
|||
// Retrieve reads and extracts the shared credentials from the current
|
|||
// users home directory.
|
|||
func (p *SharedCredentialsProvider) Retrieve() (Value, error) { |
|||
p.retrieved = false |
|||
|
|||
filename, err := p.filename() |
|||
if err != nil { |
|||
return Value{ProviderName: SharedCredsProviderName}, err |
|||
} |
|||
|
|||
creds, err := loadProfile(filename, p.profile()) |
|||
if err != nil { |
|||
return Value{ProviderName: SharedCredsProviderName}, err |
|||
} |
|||
|
|||
p.retrieved = true |
|||
return creds, nil |
|||
} |
|||
|
|||
// IsExpired returns if the shared credentials have expired.
|
|||
func (p *SharedCredentialsProvider) IsExpired() bool { |
|||
return !p.retrieved |
|||
} |
|||
|
|||
// loadProfiles loads from the file pointed to by shared credentials filename for profile.
|
|||
// The credentials retrieved from the profile will be returned or error. Error will be
|
|||
// returned if it fails to read from the file, or the data is invalid.
|
|||
func loadProfile(filename, profile string) (Value, error) { |
|||
config, err := ini.Load(filename) |
|||
if err != nil { |
|||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err) |
|||
} |
|||
iniProfile, err := config.GetSection(profile) |
|||
if err != nil { |
|||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", err) |
|||
} |
|||
|
|||
id, err := iniProfile.GetKey("aws_access_key_id") |
|||
if err != nil { |
|||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey", |
|||
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename), |
|||
err) |
|||
} |
|||
|
|||
secret, err := iniProfile.GetKey("aws_secret_access_key") |
|||
if err != nil { |
|||
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret", |
|||
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename), |
|||
nil) |
|||
} |
|||
|
|||
// Default to empty string if not found
|
|||
token := iniProfile.Key("aws_session_token") |
|||
|
|||
return Value{ |
|||
AccessKeyID: id.String(), |
|||
SecretAccessKey: secret.String(), |
|||
SessionToken: token.String(), |
|||
ProviderName: SharedCredsProviderName, |
|||
}, nil |
|||
} |
|||
|
|||
// filename returns the filename to use to read AWS shared credentials.
|
|||
//
|
|||
// Will return an error if the user's home directory path cannot be found.
|
|||
func (p *SharedCredentialsProvider) filename() (string, error) { |
|||
if p.Filename == "" { |
|||
if p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); p.Filename != "" { |
|||
return p.Filename, nil |
|||
} |
|||
|
|||
homeDir := os.Getenv("HOME") // *nix
|
|||
if homeDir == "" { // Windows
|
|||
homeDir = os.Getenv("USERPROFILE") |
|||
} |
|||
if homeDir == "" { |
|||
return "", ErrSharedCredentialsHomeNotFound |
|||
} |
|||
|
|||
p.Filename = filepath.Join(homeDir, ".aws", "credentials") |
|||
} |
|||
|
|||
return p.Filename, nil |
|||
} |
|||
|
|||
// profile returns the AWS shared credentials profile. If empty will read
|
|||
// environment variable "AWS_PROFILE". If that is not set profile will
|
|||
// return "default".
|
|||
func (p *SharedCredentialsProvider) profile() string { |
|||
if p.Profile == "" { |
|||
p.Profile = os.Getenv("AWS_PROFILE") |
|||
} |
|||
if p.Profile == "" { |
|||
p.Profile = "default" |
|||
} |
|||
|
|||
return p.Profile |
|||
} |
@ -0,0 +1,57 @@ |
|||
package credentials |
|||
|
|||
import ( |
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
) |
|||
|
|||
// StaticProviderName provides a name of Static provider
|
|||
const StaticProviderName = "StaticProvider" |
|||
|
|||
var ( |
|||
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
|
|||
//
|
|||
// @readonly
|
|||
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil) |
|||
) |
|||
|
|||
// A StaticProvider is a set of credentials which are set programmatically,
|
|||
// and will never expire.
|
|||
type StaticProvider struct { |
|||
Value |
|||
} |
|||
|
|||
// NewStaticCredentials returns a pointer to a new Credentials object
|
|||
// wrapping a static credentials value provider.
|
|||
func NewStaticCredentials(id, secret, token string) *Credentials { |
|||
return NewCredentials(&StaticProvider{Value: Value{ |
|||
AccessKeyID: id, |
|||
SecretAccessKey: secret, |
|||
SessionToken: token, |
|||
}}) |
|||
} |
|||
|
|||
// NewStaticCredentialsFromCreds returns a pointer to a new Credentials object
|
|||
// wrapping the static credentials value provide. Same as NewStaticCredentials
|
|||
// but takes the creds Value instead of individual fields
|
|||
func NewStaticCredentialsFromCreds(creds Value) *Credentials { |
|||
return NewCredentials(&StaticProvider{Value: creds}) |
|||
} |
|||
|
|||
// Retrieve returns the credentials or error if the credentials are invalid.
|
|||
func (s *StaticProvider) Retrieve() (Value, error) { |
|||
if s.AccessKeyID == "" || s.SecretAccessKey == "" { |
|||
return Value{ProviderName: StaticProviderName}, ErrStaticCredentialsEmpty |
|||
} |
|||
|
|||
if len(s.Value.ProviderName) == 0 { |
|||
s.Value.ProviderName = StaticProviderName |
|||
} |
|||
return s.Value, nil |
|||
} |
|||
|
|||
// IsExpired returns if the credentials are expired.
|
|||
//
|
|||
// For StaticProvider, the credentials never expired.
|
|||
func (s *StaticProvider) IsExpired() bool { |
|||
return false |
|||
} |
@ -0,0 +1,161 @@ |
|||
// Package stscreds are credential Providers to retrieve STS AWS credentials.
|
|||
//
|
|||
// STS provides multiple ways to retrieve credentials which can be used when making
|
|||
// future AWS service API operation calls.
|
|||
package stscreds |
|||
|
|||
import ( |
|||
"fmt" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/client" |
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
"github.com/aws/aws-sdk-go/service/sts" |
|||
) |
|||
|
|||
// ProviderName provides a name of AssumeRole provider
|
|||
const ProviderName = "AssumeRoleProvider" |
|||
|
|||
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
|
|||
type AssumeRoler interface { |
|||
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) |
|||
} |
|||
|
|||
// DefaultDuration is the default amount of time in minutes that the credentials
|
|||
// will be valid for.
|
|||
var DefaultDuration = time.Duration(15) * time.Minute |
|||
|
|||
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
|
|||
// keeps track of their expiration time. This provider must be used explicitly,
|
|||
// as it is not included in the credentials chain.
|
|||
type AssumeRoleProvider struct { |
|||
credentials.Expiry |
|||
|
|||
// STS client to make assume role request with.
|
|||
Client AssumeRoler |
|||
|
|||
// Role to be assumed.
|
|||
RoleARN string |
|||
|
|||
// Session name, if you wish to reuse the credentials elsewhere.
|
|||
RoleSessionName string |
|||
|
|||
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
|
|||
Duration time.Duration |
|||
|
|||
// Optional ExternalID to pass along, defaults to nil if not set.
|
|||
ExternalID *string |
|||
|
|||
// The policy plain text must be 2048 bytes or shorter. However, an internal
|
|||
// conversion compresses it into a packed binary format with a separate limit.
|
|||
// The PackedPolicySize response element indicates by percentage how close to
|
|||
// the upper size limit the policy is, with 100% equaling the maximum allowed
|
|||
// size.
|
|||
Policy *string |
|||
|
|||
// The identification number of the MFA device that is associated with the user
|
|||
// who is making the AssumeRole call. Specify this value if the trust policy
|
|||
// of the role being assumed includes a condition that requires MFA authentication.
|
|||
// The value is either the serial number for a hardware device (such as GAHT12345678)
|
|||
// or an Amazon Resource Name (ARN) for a virtual device (such as arn:aws:iam::123456789012:mfa/user).
|
|||
SerialNumber *string |
|||
|
|||
// The value provided by the MFA device, if the trust policy of the role being
|
|||
// assumed requires MFA (that is, if the policy includes a condition that tests
|
|||
// for MFA). If the role being assumed requires MFA and if the TokenCode value
|
|||
// is missing or expired, the AssumeRole call returns an "access denied" error.
|
|||
TokenCode *string |
|||
|
|||
// ExpiryWindow will allow the credentials to trigger refreshing prior to
|
|||
// the credentials actually expiring. This is beneficial so race conditions
|
|||
// with expiring credentials do not cause request to fail unexpectedly
|
|||
// due to ExpiredTokenException exceptions.
|
|||
//
|
|||
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
|||
// 10 seconds before the credentials are actually expired.
|
|||
//
|
|||
// If ExpiryWindow is 0 or less it will be ignored.
|
|||
ExpiryWindow time.Duration |
|||
} |
|||
|
|||
// NewCredentials returns a pointer to a new Credentials object wrapping the
|
|||
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
|
|||
// role will be named after a nanosecond timestamp of this operation.
|
|||
//
|
|||
// Takes a Config provider to create the STS client. The ConfigProvider is
|
|||
// satisfied by the session.Session type.
|
|||
func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials { |
|||
p := &AssumeRoleProvider{ |
|||
Client: sts.New(c), |
|||
RoleARN: roleARN, |
|||
Duration: DefaultDuration, |
|||
} |
|||
|
|||
for _, option := range options { |
|||
option(p) |
|||
} |
|||
|
|||
return credentials.NewCredentials(p) |
|||
} |
|||
|
|||
// NewCredentialsWithClient returns a pointer to a new Credentials object wrapping the
|
|||
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
|
|||
// role will be named after a nanosecond timestamp of this operation.
|
|||
//
|
|||
// Takes an AssumeRoler which can be satisfiede by the STS client.
|
|||
func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials { |
|||
p := &AssumeRoleProvider{ |
|||
Client: svc, |
|||
RoleARN: roleARN, |
|||
Duration: DefaultDuration, |
|||
} |
|||
|
|||
for _, option := range options { |
|||
option(p) |
|||
} |
|||
|
|||
return credentials.NewCredentials(p) |
|||
} |
|||
|
|||
// Retrieve generates a new set of temporary credentials using STS.
|
|||
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) { |
|||
|
|||
// Apply defaults where parameters are not set.
|
|||
if p.RoleSessionName == "" { |
|||
// Try to work out a role name that will hopefully end up unique.
|
|||
p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano()) |
|||
} |
|||
if p.Duration == 0 { |
|||
// Expire as often as AWS permits.
|
|||
p.Duration = DefaultDuration |
|||
} |
|||
input := &sts.AssumeRoleInput{ |
|||
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)), |
|||
RoleArn: aws.String(p.RoleARN), |
|||
RoleSessionName: aws.String(p.RoleSessionName), |
|||
ExternalId: p.ExternalID, |
|||
} |
|||
if p.Policy != nil { |
|||
input.Policy = p.Policy |
|||
} |
|||
if p.SerialNumber != nil && p.TokenCode != nil { |
|||
input.SerialNumber = p.SerialNumber |
|||
input.TokenCode = p.TokenCode |
|||
} |
|||
roleOutput, err := p.Client.AssumeRole(input) |
|||
|
|||
if err != nil { |
|||
return credentials.Value{ProviderName: ProviderName}, err |
|||
} |
|||
|
|||
// We will proactively generate new credentials before they expire.
|
|||
p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow) |
|||
|
|||
return credentials.Value{ |
|||
AccessKeyID: *roleOutput.Credentials.AccessKeyId, |
|||
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey, |
|||
SessionToken: *roleOutput.Credentials.SessionToken, |
|||
ProviderName: ProviderName, |
|||
}, nil |
|||
} |
@ -0,0 +1,130 @@ |
|||
// Package defaults is a collection of helpers to retrieve the SDK's default
|
|||
// configuration and handlers.
|
|||
//
|
|||
// Generally this package shouldn't be used directly, but session.Session
|
|||
// instead. This package is useful when you need to reset the defaults
|
|||
// of a session or service client to the SDK defaults before setting
|
|||
// additional parameters.
|
|||
package defaults |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"os" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/corehandlers" |
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" |
|||
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds" |
|||
"github.com/aws/aws-sdk-go/aws/ec2metadata" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
"github.com/aws/aws-sdk-go/private/endpoints" |
|||
) |
|||
|
|||
// A Defaults provides a collection of default values for SDK clients.
|
|||
type Defaults struct { |
|||
Config *aws.Config |
|||
Handlers request.Handlers |
|||
} |
|||
|
|||
// Get returns the SDK's default values with Config and handlers pre-configured.
|
|||
func Get() Defaults { |
|||
cfg := Config() |
|||
handlers := Handlers() |
|||
cfg.Credentials = CredChain(cfg, handlers) |
|||
|
|||
return Defaults{ |
|||
Config: cfg, |
|||
Handlers: handlers, |
|||
} |
|||
} |
|||
|
|||
// Config returns the default configuration without credentials.
|
|||
// To retrieve a config with credentials also included use
|
|||
// `defaults.Get().Config` instead.
|
|||
//
|
|||
// Generally you shouldn't need to use this method directly, but
|
|||
// is available if you need to reset the configuration of an
|
|||
// existing service client or session.
|
|||
func Config() *aws.Config { |
|||
return aws.NewConfig(). |
|||
WithCredentials(credentials.AnonymousCredentials). |
|||
WithRegion(os.Getenv("AWS_REGION")). |
|||
WithHTTPClient(http.DefaultClient). |
|||
WithMaxRetries(aws.UseServiceDefaultRetries). |
|||
WithLogger(aws.NewDefaultLogger()). |
|||
WithLogLevel(aws.LogOff). |
|||
WithSleepDelay(time.Sleep) |
|||
} |
|||
|
|||
// Handlers returns the default request handlers.
|
|||
//
|
|||
// Generally you shouldn't need to use this method directly, but
|
|||
// is available if you need to reset the request handlers of an
|
|||
// existing service client or session.
|
|||
func Handlers() request.Handlers { |
|||
var handlers request.Handlers |
|||
|
|||
handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler) |
|||
handlers.Validate.AfterEachFn = request.HandlerListStopOnError |
|||
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler) |
|||
handlers.Build.AfterEachFn = request.HandlerListStopOnError |
|||
handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler) |
|||
handlers.Send.PushBackNamed(corehandlers.ValidateReqSigHandler) |
|||
handlers.Send.PushBackNamed(corehandlers.SendHandler) |
|||
handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler) |
|||
handlers.ValidateResponse.PushBackNamed(corehandlers.ValidateResponseHandler) |
|||
|
|||
return handlers |
|||
} |
|||
|
|||
// CredChain returns the default credential chain.
|
|||
//
|
|||
// Generally you shouldn't need to use this method directly, but
|
|||
// is available if you need to reset the credentials of an
|
|||
// existing service client or session's Config.
|
|||
func CredChain(cfg *aws.Config, handlers request.Handlers) *credentials.Credentials { |
|||
return credentials.NewCredentials(&credentials.ChainProvider{ |
|||
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors), |
|||
Providers: []credentials.Provider{ |
|||
&credentials.EnvProvider{}, |
|||
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""}, |
|||
RemoteCredProvider(*cfg, handlers), |
|||
}, |
|||
}) |
|||
} |
|||
|
|||
// RemoteCredProvider returns a credenitials provider for the default remote
|
|||
// endpoints such as EC2 or ECS Roles.
|
|||
func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider { |
|||
ecsCredURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") |
|||
|
|||
if len(ecsCredURI) > 0 { |
|||
return ecsCredProvider(cfg, handlers, ecsCredURI) |
|||
} |
|||
|
|||
return ec2RoleProvider(cfg, handlers) |
|||
} |
|||
|
|||
func ecsCredProvider(cfg aws.Config, handlers request.Handlers, uri string) credentials.Provider { |
|||
const host = `169.254.170.2` |
|||
|
|||
return endpointcreds.NewProviderClient(cfg, handlers, |
|||
fmt.Sprintf("http://%s%s", host, uri), |
|||
func(p *endpointcreds.Provider) { |
|||
p.ExpiryWindow = 5 * time.Minute |
|||
}, |
|||
) |
|||
} |
|||
|
|||
func ec2RoleProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider { |
|||
endpoint, signingRegion := endpoints.EndpointForRegion(ec2metadata.ServiceName, |
|||
aws.StringValue(cfg.Region), true, false) |
|||
|
|||
return &ec2rolecreds.EC2RoleProvider{ |
|||
Client: ec2metadata.NewClient(cfg, handlers, endpoint, signingRegion), |
|||
ExpiryWindow: 5 * time.Minute, |
|||
} |
|||
} |
@ -0,0 +1,162 @@ |
|||
package ec2metadata |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"fmt" |
|||
"net/http" |
|||
"path" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
// GetMetadata uses the path provided to request information from the EC2
|
|||
// instance metdata service. The content will be returned as a string, or
|
|||
// error if the request failed.
|
|||
func (c *EC2Metadata) GetMetadata(p string) (string, error) { |
|||
op := &request.Operation{ |
|||
Name: "GetMetadata", |
|||
HTTPMethod: "GET", |
|||
HTTPPath: path.Join("/", "meta-data", p), |
|||
} |
|||
|
|||
output := &metadataOutput{} |
|||
req := c.NewRequest(op, nil, output) |
|||
|
|||
return output.Content, req.Send() |
|||
} |
|||
|
|||
// GetUserData returns the userdata that was configured for the service. If
|
|||
// there is no user-data setup for the EC2 instance a "NotFoundError" error
|
|||
// code will be returned.
|
|||
func (c *EC2Metadata) GetUserData() (string, error) { |
|||
op := &request.Operation{ |
|||
Name: "GetUserData", |
|||
HTTPMethod: "GET", |
|||
HTTPPath: path.Join("/", "user-data"), |
|||
} |
|||
|
|||
output := &metadataOutput{} |
|||
req := c.NewRequest(op, nil, output) |
|||
req.Handlers.UnmarshalError.PushBack(func(r *request.Request) { |
|||
if r.HTTPResponse.StatusCode == http.StatusNotFound { |
|||
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error) |
|||
} |
|||
}) |
|||
|
|||
return output.Content, req.Send() |
|||
} |
|||
|
|||
// GetDynamicData uses the path provided to request information from the EC2
|
|||
// instance metadata service for dynamic data. The content will be returned
|
|||
// as a string, or error if the request failed.
|
|||
func (c *EC2Metadata) GetDynamicData(p string) (string, error) { |
|||
op := &request.Operation{ |
|||
Name: "GetDynamicData", |
|||
HTTPMethod: "GET", |
|||
HTTPPath: path.Join("/", "dynamic", p), |
|||
} |
|||
|
|||
output := &metadataOutput{} |
|||
req := c.NewRequest(op, nil, output) |
|||
|
|||
return output.Content, req.Send() |
|||
} |
|||
|
|||
// GetInstanceIdentityDocument retrieves an identity document describing an
|
|||
// instance. Error is returned if the request fails or is unable to parse
|
|||
// the response.
|
|||
func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument, error) { |
|||
resp, err := c.GetDynamicData("instance-identity/document") |
|||
if err != nil { |
|||
return EC2InstanceIdentityDocument{}, |
|||
awserr.New("EC2MetadataRequestError", |
|||
"failed to get EC2 instance identity document", err) |
|||
} |
|||
|
|||
doc := EC2InstanceIdentityDocument{} |
|||
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil { |
|||
return EC2InstanceIdentityDocument{}, |
|||
awserr.New("SerializationError", |
|||
"failed to decode EC2 instance identity document", err) |
|||
} |
|||
|
|||
return doc, nil |
|||
} |
|||
|
|||
// IAMInfo retrieves IAM info from the metadata API
|
|||
func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) { |
|||
resp, err := c.GetMetadata("iam/info") |
|||
if err != nil { |
|||
return EC2IAMInfo{}, |
|||
awserr.New("EC2MetadataRequestError", |
|||
"failed to get EC2 IAM info", err) |
|||
} |
|||
|
|||
info := EC2IAMInfo{} |
|||
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil { |
|||
return EC2IAMInfo{}, |
|||
awserr.New("SerializationError", |
|||
"failed to decode EC2 IAM info", err) |
|||
} |
|||
|
|||
if info.Code != "Success" { |
|||
errMsg := fmt.Sprintf("failed to get EC2 IAM Info (%s)", info.Code) |
|||
return EC2IAMInfo{}, |
|||
awserr.New("EC2MetadataError", errMsg, nil) |
|||
} |
|||
|
|||
return info, nil |
|||
} |
|||
|
|||
// Region returns the region the instance is running in.
|
|||
func (c *EC2Metadata) Region() (string, error) { |
|||
resp, err := c.GetMetadata("placement/availability-zone") |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
|
|||
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
|
|||
return resp[:len(resp)-1], nil |
|||
} |
|||
|
|||
// Available returns if the application has access to the EC2 Metadata service.
|
|||
// Can be used to determine if application is running within an EC2 Instance and
|
|||
// the metadata service is available.
|
|||
func (c *EC2Metadata) Available() bool { |
|||
if _, err := c.GetMetadata("instance-id"); err != nil { |
|||
return false |
|||
} |
|||
|
|||
return true |
|||
} |
|||
|
|||
// An EC2IAMInfo provides the shape for unmarshalling
|
|||
// an IAM info from the metadata API
|
|||
type EC2IAMInfo struct { |
|||
Code string |
|||
LastUpdated time.Time |
|||
InstanceProfileArn string |
|||
InstanceProfileID string |
|||
} |
|||
|
|||
// An EC2InstanceIdentityDocument provides the shape for unmarshalling
|
|||
// an instance identity document
|
|||
type EC2InstanceIdentityDocument struct { |
|||
DevpayProductCodes []string `json:"devpayProductCodes"` |
|||
AvailabilityZone string `json:"availabilityZone"` |
|||
PrivateIP string `json:"privateIp"` |
|||
Version string `json:"version"` |
|||
Region string `json:"region"` |
|||
InstanceID string `json:"instanceId"` |
|||
BillingProducts []string `json:"billingProducts"` |
|||
InstanceType string `json:"instanceType"` |
|||
AccountID string `json:"accountId"` |
|||
PendingTime time.Time `json:"pendingTime"` |
|||
ImageID string `json:"imageId"` |
|||
KernelID string `json:"kernelId"` |
|||
RamdiskID string `json:"ramdiskId"` |
|||
Architecture string `json:"architecture"` |
|||
} |
@ -0,0 +1,124 @@ |
|||
// Package ec2metadata provides the client for making API calls to the
|
|||
// EC2 Metadata service.
|
|||
package ec2metadata |
|||
|
|||
import ( |
|||
"bytes" |
|||
"errors" |
|||
"io" |
|||
"net/http" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/client" |
|||
"github.com/aws/aws-sdk-go/aws/client/metadata" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
// ServiceName is the name of the service.
|
|||
const ServiceName = "ec2metadata" |
|||
|
|||
// A EC2Metadata is an EC2 Metadata service Client.
|
|||
type EC2Metadata struct { |
|||
*client.Client |
|||
} |
|||
|
|||
// New creates a new instance of the EC2Metadata client with a session.
|
|||
// This client is safe to use across multiple goroutines.
|
|||
//
|
|||
//
|
|||
// Example:
|
|||
// // Create a EC2Metadata client from just a session.
|
|||
// svc := ec2metadata.New(mySession)
|
|||
//
|
|||
// // Create a EC2Metadata client with additional configuration
|
|||
// svc := ec2metadata.New(mySession, aws.NewConfig().WithLogLevel(aws.LogDebugHTTPBody))
|
|||
func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata { |
|||
c := p.ClientConfig(ServiceName, cfgs...) |
|||
return NewClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion) |
|||
} |
|||
|
|||
// NewClient returns a new EC2Metadata client. Should be used to create
|
|||
// a client when not using a session. Generally using just New with a session
|
|||
// is preferred.
|
|||
//
|
|||
// If an unmodified HTTP client is provided from the stdlib default, or no client
|
|||
// the EC2RoleProvider's EC2Metadata HTTP client's timeout will be shortened.
|
|||
// To disable this set Config.EC2MetadataDisableTimeoutOverride to false. Enabled by default.
|
|||
func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string, opts ...func(*client.Client)) *EC2Metadata { |
|||
if !aws.BoolValue(cfg.EC2MetadataDisableTimeoutOverride) && httpClientZero(cfg.HTTPClient) { |
|||
// If the http client is unmodified and this feature is not disabled
|
|||
// set custom timeouts for EC2Metadata requests.
|
|||
cfg.HTTPClient = &http.Client{ |
|||
// use a shorter timeout than default because the metadata
|
|||
// service is local if it is running, and to fail faster
|
|||
// if not running on an ec2 instance.
|
|||
Timeout: 5 * time.Second, |
|||
} |
|||
} |
|||
|
|||
svc := &EC2Metadata{ |
|||
Client: client.New( |
|||
cfg, |
|||
metadata.ClientInfo{ |
|||
ServiceName: ServiceName, |
|||
Endpoint: endpoint, |
|||
APIVersion: "latest", |
|||
}, |
|||
handlers, |
|||
), |
|||
} |
|||
|
|||
svc.Handlers.Unmarshal.PushBack(unmarshalHandler) |
|||
svc.Handlers.UnmarshalError.PushBack(unmarshalError) |
|||
svc.Handlers.Validate.Clear() |
|||
svc.Handlers.Validate.PushBack(validateEndpointHandler) |
|||
|
|||
// Add additional options to the service config
|
|||
for _, option := range opts { |
|||
option(svc.Client) |
|||
} |
|||
|
|||
return svc |
|||
} |
|||
|
|||
func httpClientZero(c *http.Client) bool { |
|||
return c == nil || (c.Transport == nil && c.CheckRedirect == nil && c.Jar == nil && c.Timeout == 0) |
|||
} |
|||
|
|||
type metadataOutput struct { |
|||
Content string |
|||
} |
|||
|
|||
func unmarshalHandler(r *request.Request) { |
|||
defer r.HTTPResponse.Body.Close() |
|||
b := &bytes.Buffer{} |
|||
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil { |
|||
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err) |
|||
return |
|||
} |
|||
|
|||
if data, ok := r.Data.(*metadataOutput); ok { |
|||
data.Content = b.String() |
|||
} |
|||
} |
|||
|
|||
func unmarshalError(r *request.Request) { |
|||
defer r.HTTPResponse.Body.Close() |
|||
b := &bytes.Buffer{} |
|||
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil { |
|||
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err) |
|||
return |
|||
} |
|||
|
|||
// Response body format is not consistent between metadata endpoints.
|
|||
// Grab the error message as a string and include that as the source error
|
|||
r.Error = awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String())) |
|||
} |
|||
|
|||
func validateEndpointHandler(r *request.Request) { |
|||
if r.ClientInfo.Endpoint == "" { |
|||
r.Error = aws.ErrMissingEndpoint |
|||
} |
|||
} |
@ -0,0 +1,17 @@ |
|||
package aws |
|||
|
|||
import "github.com/aws/aws-sdk-go/aws/awserr" |
|||
|
|||
var ( |
|||
// ErrMissingRegion is an error that is returned if region configuration is
|
|||
// not found.
|
|||
//
|
|||
// @readonly
|
|||
ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil) |
|||
|
|||
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
|
|||
// resolved for a service.
|
|||
//
|
|||
// @readonly
|
|||
ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil) |
|||
) |
@ -0,0 +1,112 @@ |
|||
package aws |
|||
|
|||
import ( |
|||
"log" |
|||
"os" |
|||
) |
|||
|
|||
// A LogLevelType defines the level logging should be performed at. Used to instruct
|
|||
// the SDK which statements should be logged.
|
|||
type LogLevelType uint |
|||
|
|||
// LogLevel returns the pointer to a LogLevel. Should be used to workaround
|
|||
// not being able to take the address of a non-composite literal.
|
|||
func LogLevel(l LogLevelType) *LogLevelType { |
|||
return &l |
|||
} |
|||
|
|||
// Value returns the LogLevel value or the default value LogOff if the LogLevel
|
|||
// is nil. Safe to use on nil value LogLevelTypes.
|
|||
func (l *LogLevelType) Value() LogLevelType { |
|||
if l != nil { |
|||
return *l |
|||
} |
|||
return LogOff |
|||
} |
|||
|
|||
// Matches returns true if the v LogLevel is enabled by this LogLevel. Should be
|
|||
// used with logging sub levels. Is safe to use on nil value LogLevelTypes. If
|
|||
// LogLevel is nill, will default to LogOff comparison.
|
|||
func (l *LogLevelType) Matches(v LogLevelType) bool { |
|||
c := l.Value() |
|||
return c&v == v |
|||
} |
|||
|
|||
// AtLeast returns true if this LogLevel is at least high enough to satisfies v.
|
|||
// Is safe to use on nil value LogLevelTypes. If LogLevel is nill, will default
|
|||
// to LogOff comparison.
|
|||
func (l *LogLevelType) AtLeast(v LogLevelType) bool { |
|||
c := l.Value() |
|||
return c >= v |
|||
} |
|||
|
|||
const ( |
|||
// LogOff states that no logging should be performed by the SDK. This is the
|
|||
// default state of the SDK, and should be use to disable all logging.
|
|||
LogOff LogLevelType = iota * 0x1000 |
|||
|
|||
// LogDebug state that debug output should be logged by the SDK. This should
|
|||
// be used to inspect request made and responses received.
|
|||
LogDebug |
|||
) |
|||
|
|||
// Debug Logging Sub Levels
|
|||
const ( |
|||
// LogDebugWithSigning states that the SDK should log request signing and
|
|||
// presigning events. This should be used to log the signing details of
|
|||
// requests for debugging. Will also enable LogDebug.
|
|||
LogDebugWithSigning LogLevelType = LogDebug | (1 << iota) |
|||
|
|||
// LogDebugWithHTTPBody states the SDK should log HTTP request and response
|
|||
// HTTP bodys in addition to the headers and path. This should be used to
|
|||
// see the body content of requests and responses made while using the SDK
|
|||
// Will also enable LogDebug.
|
|||
LogDebugWithHTTPBody |
|||
|
|||
// LogDebugWithRequestRetries states the SDK should log when service requests will
|
|||
// be retried. This should be used to log when you want to log when service
|
|||
// requests are being retried. Will also enable LogDebug.
|
|||
LogDebugWithRequestRetries |
|||
|
|||
// LogDebugWithRequestErrors states the SDK should log when service requests fail
|
|||
// to build, send, validate, or unmarshal.
|
|||
LogDebugWithRequestErrors |
|||
) |
|||
|
|||
// A Logger is a minimalistic interface for the SDK to log messages to. Should
|
|||
// be used to provide custom logging writers for the SDK to use.
|
|||
type Logger interface { |
|||
Log(...interface{}) |
|||
} |
|||
|
|||
// A LoggerFunc is a convenience type to convert a function taking a variadic
|
|||
// list of arguments and wrap it so the Logger interface can be used.
|
|||
//
|
|||
// Example:
|
|||
// s3.New(sess, &aws.Config{Logger: aws.LoggerFunc(func(args ...interface{}) {
|
|||
// fmt.Fprintln(os.Stdout, args...)
|
|||
// })})
|
|||
type LoggerFunc func(...interface{}) |
|||
|
|||
// Log calls the wrapped function with the arguments provided
|
|||
func (f LoggerFunc) Log(args ...interface{}) { |
|||
f(args...) |
|||
} |
|||
|
|||
// NewDefaultLogger returns a Logger which will write log messages to stdout, and
|
|||
// use same formatting runes as the stdlib log.Logger
|
|||
func NewDefaultLogger() Logger { |
|||
return &defaultLogger{ |
|||
logger: log.New(os.Stdout, "", log.LstdFlags), |
|||
} |
|||
} |
|||
|
|||
// A defaultLogger provides a minimalistic logger satisfying the Logger interface.
|
|||
type defaultLogger struct { |
|||
logger *log.Logger |
|||
} |
|||
|
|||
// Log logs the parameters to the stdlib logger. See log.Println.
|
|||
func (l defaultLogger) Log(args ...interface{}) { |
|||
l.logger.Println(args...) |
|||
} |
@ -0,0 +1,187 @@ |
|||
package request |
|||
|
|||
import ( |
|||
"fmt" |
|||
"strings" |
|||
) |
|||
|
|||
// A Handlers provides a collection of request handlers for various
|
|||
// stages of handling requests.
|
|||
type Handlers struct { |
|||
Validate HandlerList |
|||
Build HandlerList |
|||
Sign HandlerList |
|||
Send HandlerList |
|||
ValidateResponse HandlerList |
|||
Unmarshal HandlerList |
|||
UnmarshalMeta HandlerList |
|||
UnmarshalError HandlerList |
|||
Retry HandlerList |
|||
AfterRetry HandlerList |
|||
} |
|||
|
|||
// Copy returns of this handler's lists.
|
|||
func (h *Handlers) Copy() Handlers { |
|||
return Handlers{ |
|||
Validate: h.Validate.copy(), |
|||
Build: h.Build.copy(), |
|||
Sign: h.Sign.copy(), |
|||
Send: h.Send.copy(), |
|||
ValidateResponse: h.ValidateResponse.copy(), |
|||
Unmarshal: h.Unmarshal.copy(), |
|||
UnmarshalError: h.UnmarshalError.copy(), |
|||
UnmarshalMeta: h.UnmarshalMeta.copy(), |
|||
Retry: h.Retry.copy(), |
|||
AfterRetry: h.AfterRetry.copy(), |
|||
} |
|||
} |
|||
|
|||
// Clear removes callback functions for all handlers
|
|||
func (h *Handlers) Clear() { |
|||
h.Validate.Clear() |
|||
h.Build.Clear() |
|||
h.Send.Clear() |
|||
h.Sign.Clear() |
|||
h.Unmarshal.Clear() |
|||
h.UnmarshalMeta.Clear() |
|||
h.UnmarshalError.Clear() |
|||
h.ValidateResponse.Clear() |
|||
h.Retry.Clear() |
|||
h.AfterRetry.Clear() |
|||
} |
|||
|
|||
// A HandlerListRunItem represents an entry in the HandlerList which
|
|||
// is being run.
|
|||
type HandlerListRunItem struct { |
|||
Index int |
|||
Handler NamedHandler |
|||
Request *Request |
|||
} |
|||
|
|||
// A HandlerList manages zero or more handlers in a list.
|
|||
type HandlerList struct { |
|||
list []NamedHandler |
|||
|
|||
// Called after each request handler in the list is called. If set
|
|||
// and the func returns true the HandlerList will continue to iterate
|
|||
// over the request handlers. If false is returned the HandlerList
|
|||
// will stop iterating.
|
|||
//
|
|||
// Should be used if extra logic to be performed between each handler
|
|||
// in the list. This can be used to terminate a list's iteration
|
|||
// based on a condition such as error like, HandlerListStopOnError.
|
|||
// Or for logging like HandlerListLogItem.
|
|||
AfterEachFn func(item HandlerListRunItem) bool |
|||
} |
|||
|
|||
// A NamedHandler is a struct that contains a name and function callback.
|
|||
type NamedHandler struct { |
|||
Name string |
|||
Fn func(*Request) |
|||
} |
|||
|
|||
// copy creates a copy of the handler list.
|
|||
func (l *HandlerList) copy() HandlerList { |
|||
n := HandlerList{ |
|||
AfterEachFn: l.AfterEachFn, |
|||
} |
|||
n.list = append([]NamedHandler{}, l.list...) |
|||
return n |
|||
} |
|||
|
|||
// Clear clears the handler list.
|
|||
func (l *HandlerList) Clear() { |
|||
l.list = []NamedHandler{} |
|||
} |
|||
|
|||
// Len returns the number of handlers in the list.
|
|||
func (l *HandlerList) Len() int { |
|||
return len(l.list) |
|||
} |
|||
|
|||
// PushBack pushes handler f to the back of the handler list.
|
|||
func (l *HandlerList) PushBack(f func(*Request)) { |
|||
l.list = append(l.list, NamedHandler{"__anonymous", f}) |
|||
} |
|||
|
|||
// PushFront pushes handler f to the front of the handler list.
|
|||
func (l *HandlerList) PushFront(f func(*Request)) { |
|||
l.list = append([]NamedHandler{{"__anonymous", f}}, l.list...) |
|||
} |
|||
|
|||
// PushBackNamed pushes named handler f to the back of the handler list.
|
|||
func (l *HandlerList) PushBackNamed(n NamedHandler) { |
|||
l.list = append(l.list, n) |
|||
} |
|||
|
|||
// PushFrontNamed pushes named handler f to the front of the handler list.
|
|||
func (l *HandlerList) PushFrontNamed(n NamedHandler) { |
|||
l.list = append([]NamedHandler{n}, l.list...) |
|||
} |
|||
|
|||
// Remove removes a NamedHandler n
|
|||
func (l *HandlerList) Remove(n NamedHandler) { |
|||
newlist := []NamedHandler{} |
|||
for _, m := range l.list { |
|||
if m.Name != n.Name { |
|||
newlist = append(newlist, m) |
|||
} |
|||
} |
|||
l.list = newlist |
|||
} |
|||
|
|||
// Run executes all handlers in the list with a given request object.
|
|||
func (l *HandlerList) Run(r *Request) { |
|||
for i, h := range l.list { |
|||
h.Fn(r) |
|||
item := HandlerListRunItem{ |
|||
Index: i, Handler: h, Request: r, |
|||
} |
|||
if l.AfterEachFn != nil && !l.AfterEachFn(item) { |
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
// HandlerListLogItem logs the request handler and the state of the
|
|||
// request's Error value. Always returns true to continue iterating
|
|||
// request handlers in a HandlerList.
|
|||
func HandlerListLogItem(item HandlerListRunItem) bool { |
|||
if item.Request.Config.Logger == nil { |
|||
return true |
|||
} |
|||
item.Request.Config.Logger.Log("DEBUG: RequestHandler", |
|||
item.Index, item.Handler.Name, item.Request.Error) |
|||
|
|||
return true |
|||
} |
|||
|
|||
// HandlerListStopOnError returns false to stop the HandlerList iterating
|
|||
// over request handlers if Request.Error is not nil. True otherwise
|
|||
// to continue iterating.
|
|||
func HandlerListStopOnError(item HandlerListRunItem) bool { |
|||
return item.Request.Error == nil |
|||
} |
|||
|
|||
// MakeAddToUserAgentHandler will add the name/version pair to the User-Agent request
|
|||
// header. If the extra parameters are provided they will be added as metadata to the
|
|||
// name/version pair resulting in the following format.
|
|||
// "name/version (extra0; extra1; ...)"
|
|||
// The user agent part will be concatenated with this current request's user agent string.
|
|||
func MakeAddToUserAgentHandler(name, version string, extra ...string) func(*Request) { |
|||
ua := fmt.Sprintf("%s/%s", name, version) |
|||
if len(extra) > 0 { |
|||
ua += fmt.Sprintf(" (%s)", strings.Join(extra, "; ")) |
|||
} |
|||
return func(r *Request) { |
|||
AddToUserAgent(r, ua) |
|||
} |
|||
} |
|||
|
|||
// MakeAddToUserAgentFreeFormHandler adds the input to the User-Agent request header.
|
|||
// The input string will be concatenated with the current request's user agent string.
|
|||
func MakeAddToUserAgentFreeFormHandler(s string) func(*Request) { |
|||
return func(r *Request) { |
|||
AddToUserAgent(r, s) |
|||
} |
|||
} |
@ -0,0 +1,24 @@ |
|||
package request |
|||
|
|||
import ( |
|||
"io" |
|||
"net/http" |
|||
"net/url" |
|||
) |
|||
|
|||
func copyHTTPRequest(r *http.Request, body io.ReadCloser) *http.Request { |
|||
req := new(http.Request) |
|||
*req = *r |
|||
req.URL = &url.URL{} |
|||
*req.URL = *r.URL |
|||
req.Body = body |
|||
|
|||
req.Header = http.Header{} |
|||
for k, v := range r.Header { |
|||
for _, vv := range v { |
|||
req.Header.Add(k, vv) |
|||
} |
|||
} |
|||
|
|||
return req |
|||
} |
@ -0,0 +1,58 @@ |
|||
package request |
|||
|
|||
import ( |
|||
"io" |
|||
"sync" |
|||
) |
|||
|
|||
// offsetReader is a thread-safe io.ReadCloser to prevent racing
|
|||
// with retrying requests
|
|||
type offsetReader struct { |
|||
buf io.ReadSeeker |
|||
lock sync.Mutex |
|||
closed bool |
|||
} |
|||
|
|||
func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader { |
|||
reader := &offsetReader{} |
|||
buf.Seek(offset, 0) |
|||
|
|||
reader.buf = buf |
|||
return reader |
|||
} |
|||
|
|||
// Close will close the instance of the offset reader's access to
|
|||
// the underlying io.ReadSeeker.
|
|||
func (o *offsetReader) Close() error { |
|||
o.lock.Lock() |
|||
defer o.lock.Unlock() |
|||
o.closed = true |
|||
return nil |
|||
} |
|||
|
|||
// Read is a thread-safe read of the underlying io.ReadSeeker
|
|||
func (o *offsetReader) Read(p []byte) (int, error) { |
|||
o.lock.Lock() |
|||
defer o.lock.Unlock() |
|||
|
|||
if o.closed { |
|||
return 0, io.EOF |
|||
} |
|||
|
|||
return o.buf.Read(p) |
|||
} |
|||
|
|||
// Seek is a thread-safe seeking operation.
|
|||
func (o *offsetReader) Seek(offset int64, whence int) (int64, error) { |
|||
o.lock.Lock() |
|||
defer o.lock.Unlock() |
|||
|
|||
return o.buf.Seek(offset, whence) |
|||
} |
|||
|
|||
// CloseAndCopy will return a new offsetReader with a copy of the old buffer
|
|||
// and close the old buffer.
|
|||
func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader { |
|||
o.Close() |
|||
return newOffsetReader(o.buf, offset) |
|||
} |
@ -0,0 +1,344 @@ |
|||
package request |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"io" |
|||
"net/http" |
|||
"net/url" |
|||
"reflect" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/client/metadata" |
|||
) |
|||
|
|||
// A Request is the service request to be made.
|
|||
type Request struct { |
|||
Config aws.Config |
|||
ClientInfo metadata.ClientInfo |
|||
Handlers Handlers |
|||
|
|||
Retryer |
|||
Time time.Time |
|||
ExpireTime time.Duration |
|||
Operation *Operation |
|||
HTTPRequest *http.Request |
|||
HTTPResponse *http.Response |
|||
Body io.ReadSeeker |
|||
BodyStart int64 // offset from beginning of Body that the request body starts
|
|||
Params interface{} |
|||
Error error |
|||
Data interface{} |
|||
RequestID string |
|||
RetryCount int |
|||
Retryable *bool |
|||
RetryDelay time.Duration |
|||
NotHoist bool |
|||
SignedHeaderVals http.Header |
|||
LastSignedAt time.Time |
|||
|
|||
built bool |
|||
|
|||
// Need to persist an intermideant body betweend the input Body and HTTP
|
|||
// request body because the HTTP Client's transport can maintain a reference
|
|||
// to the HTTP request's body after the client has returned. This value is
|
|||
// safe to use concurrently and rewraps the input Body for each HTTP request.
|
|||
safeBody *offsetReader |
|||
} |
|||
|
|||
// An Operation is the service API operation to be made.
|
|||
type Operation struct { |
|||
Name string |
|||
HTTPMethod string |
|||
HTTPPath string |
|||
*Paginator |
|||
} |
|||
|
|||
// Paginator keeps track of pagination configuration for an API operation.
|
|||
type Paginator struct { |
|||
InputTokens []string |
|||
OutputTokens []string |
|||
LimitToken string |
|||
TruncationToken string |
|||
} |
|||
|
|||
// New returns a new Request pointer for the service API
|
|||
// operation and parameters.
|
|||
//
|
|||
// Params is any value of input parameters to be the request payload.
|
|||
// Data is pointer value to an object which the request's response
|
|||
// payload will be deserialized to.
|
|||
func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers, |
|||
retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request { |
|||
|
|||
method := operation.HTTPMethod |
|||
if method == "" { |
|||
method = "POST" |
|||
} |
|||
|
|||
httpReq, _ := http.NewRequest(method, "", nil) |
|||
|
|||
var err error |
|||
httpReq.URL, err = url.Parse(clientInfo.Endpoint + operation.HTTPPath) |
|||
if err != nil { |
|||
httpReq.URL = &url.URL{} |
|||
err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err) |
|||
} |
|||
|
|||
r := &Request{ |
|||
Config: cfg, |
|||
ClientInfo: clientInfo, |
|||
Handlers: handlers.Copy(), |
|||
|
|||
Retryer: retryer, |
|||
Time: time.Now(), |
|||
ExpireTime: 0, |
|||
Operation: operation, |
|||
HTTPRequest: httpReq, |
|||
Body: nil, |
|||
Params: params, |
|||
Error: err, |
|||
Data: data, |
|||
} |
|||
r.SetBufferBody([]byte{}) |
|||
|
|||
return r |
|||
} |
|||
|
|||
// WillRetry returns if the request's can be retried.
|
|||
func (r *Request) WillRetry() bool { |
|||
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries() |
|||
} |
|||
|
|||
// ParamsFilled returns if the request's parameters have been populated
|
|||
// and the parameters are valid. False is returned if no parameters are
|
|||
// provided or invalid.
|
|||
func (r *Request) ParamsFilled() bool { |
|||
return r.Params != nil && reflect.ValueOf(r.Params).Elem().IsValid() |
|||
} |
|||
|
|||
// DataFilled returns true if the request's data for response deserialization
|
|||
// target has been set and is a valid. False is returned if data is not
|
|||
// set, or is invalid.
|
|||
func (r *Request) DataFilled() bool { |
|||
return r.Data != nil && reflect.ValueOf(r.Data).Elem().IsValid() |
|||
} |
|||
|
|||
// SetBufferBody will set the request's body bytes that will be sent to
|
|||
// the service API.
|
|||
func (r *Request) SetBufferBody(buf []byte) { |
|||
r.SetReaderBody(bytes.NewReader(buf)) |
|||
} |
|||
|
|||
// SetStringBody sets the body of the request to be backed by a string.
|
|||
func (r *Request) SetStringBody(s string) { |
|||
r.SetReaderBody(strings.NewReader(s)) |
|||
} |
|||
|
|||
// SetReaderBody will set the request's body reader.
|
|||
func (r *Request) SetReaderBody(reader io.ReadSeeker) { |
|||
r.Body = reader |
|||
r.ResetBody() |
|||
} |
|||
|
|||
// Presign returns the request's signed URL. Error will be returned
|
|||
// if the signing fails.
|
|||
func (r *Request) Presign(expireTime time.Duration) (string, error) { |
|||
r.ExpireTime = expireTime |
|||
r.NotHoist = false |
|||
r.Sign() |
|||
if r.Error != nil { |
|||
return "", r.Error |
|||
} |
|||
return r.HTTPRequest.URL.String(), nil |
|||
} |
|||
|
|||
// PresignRequest behaves just like presign, but hoists all headers and signs them.
|
|||
// Also returns the signed hash back to the user
|
|||
func (r *Request) PresignRequest(expireTime time.Duration) (string, http.Header, error) { |
|||
r.ExpireTime = expireTime |
|||
r.NotHoist = true |
|||
r.Sign() |
|||
if r.Error != nil { |
|||
return "", nil, r.Error |
|||
} |
|||
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil |
|||
} |
|||
|
|||
func debugLogReqError(r *Request, stage string, retrying bool, err error) { |
|||
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) { |
|||
return |
|||
} |
|||
|
|||
retryStr := "not retrying" |
|||
if retrying { |
|||
retryStr = "will retry" |
|||
} |
|||
|
|||
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v", |
|||
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err)) |
|||
} |
|||
|
|||
// Build will build the request's object so it can be signed and sent
|
|||
// to the service. Build will also validate all the request's parameters.
|
|||
// Anny additional build Handlers set on this request will be run
|
|||
// in the order they were set.
|
|||
//
|
|||
// The request will only be built once. Multiple calls to build will have
|
|||
// no effect.
|
|||
//
|
|||
// If any Validate or Build errors occur the build will stop and the error
|
|||
// which occurred will be returned.
|
|||
func (r *Request) Build() error { |
|||
if !r.built { |
|||
r.Handlers.Validate.Run(r) |
|||
if r.Error != nil { |
|||
debugLogReqError(r, "Validate Request", false, r.Error) |
|||
return r.Error |
|||
} |
|||
r.Handlers.Build.Run(r) |
|||
if r.Error != nil { |
|||
debugLogReqError(r, "Build Request", false, r.Error) |
|||
return r.Error |
|||
} |
|||
r.built = true |
|||
} |
|||
|
|||
return r.Error |
|||
} |
|||
|
|||
// Sign will sign the request returning error if errors are encountered.
|
|||
//
|
|||
// Send will build the request prior to signing. All Sign Handlers will
|
|||
// be executed in the order they were set.
|
|||
func (r *Request) Sign() error { |
|||
r.Build() |
|||
if r.Error != nil { |
|||
debugLogReqError(r, "Build Request", false, r.Error) |
|||
return r.Error |
|||
} |
|||
|
|||
r.Handlers.Sign.Run(r) |
|||
return r.Error |
|||
} |
|||
|
|||
// ResetBody rewinds the request body backto its starting position, and
|
|||
// set's the HTTP Request body reference. When the body is read prior
|
|||
// to being sent in the HTTP request it will need to be rewound.
|
|||
func (r *Request) ResetBody() { |
|||
if r.safeBody != nil { |
|||
r.safeBody.Close() |
|||
} |
|||
|
|||
r.safeBody = newOffsetReader(r.Body, r.BodyStart) |
|||
r.HTTPRequest.Body = r.safeBody |
|||
} |
|||
|
|||
// GetBody will return an io.ReadSeeker of the Request's underlying
|
|||
// input body with a concurrency safe wrapper.
|
|||
func (r *Request) GetBody() io.ReadSeeker { |
|||
return r.safeBody |
|||
} |
|||
|
|||
// Send will send the request returning error if errors are encountered.
|
|||
//
|
|||
// Send will sign the request prior to sending. All Send Handlers will
|
|||
// be executed in the order they were set.
|
|||
//
|
|||
// Canceling a request is non-deterministic. If a request has been canceled,
|
|||
// then the transport will choose, randomly, one of the state channels during
|
|||
// reads or getting the connection.
|
|||
//
|
|||
// readLoop() and getConn(req *Request, cm connectMethod)
|
|||
// https://github.com/golang/go/blob/master/src/net/http/transport.go
|
|||
//
|
|||
// Send will not close the request.Request's body.
|
|||
func (r *Request) Send() error { |
|||
for { |
|||
if aws.BoolValue(r.Retryable) { |
|||
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) { |
|||
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d", |
|||
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount)) |
|||
} |
|||
|
|||
// The previous http.Request will have a reference to the r.Body
|
|||
// and the HTTP Client's Transport may still be reading from
|
|||
// the request's body even though the Client's Do returned.
|
|||
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil) |
|||
r.ResetBody() |
|||
|
|||
// Closing response body to ensure that no response body is leaked
|
|||
// between retry attempts.
|
|||
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil { |
|||
r.HTTPResponse.Body.Close() |
|||
} |
|||
} |
|||
|
|||
r.Sign() |
|||
if r.Error != nil { |
|||
return r.Error |
|||
} |
|||
|
|||
r.Retryable = nil |
|||
|
|||
r.Handlers.Send.Run(r) |
|||
if r.Error != nil { |
|||
if strings.Contains(r.Error.Error(), "net/http: request canceled") { |
|||
return r.Error |
|||
} |
|||
|
|||
err := r.Error |
|||
r.Handlers.Retry.Run(r) |
|||
r.Handlers.AfterRetry.Run(r) |
|||
if r.Error != nil { |
|||
debugLogReqError(r, "Send Request", false, r.Error) |
|||
return r.Error |
|||
} |
|||
debugLogReqError(r, "Send Request", true, err) |
|||
continue |
|||
} |
|||
r.Handlers.UnmarshalMeta.Run(r) |
|||
r.Handlers.ValidateResponse.Run(r) |
|||
if r.Error != nil { |
|||
err := r.Error |
|||
r.Handlers.UnmarshalError.Run(r) |
|||
r.Handlers.Retry.Run(r) |
|||
r.Handlers.AfterRetry.Run(r) |
|||
if r.Error != nil { |
|||
debugLogReqError(r, "Validate Response", false, r.Error) |
|||
return r.Error |
|||
} |
|||
debugLogReqError(r, "Validate Response", true, err) |
|||
continue |
|||
} |
|||
|
|||
r.Handlers.Unmarshal.Run(r) |
|||
if r.Error != nil { |
|||
err := r.Error |
|||
r.Handlers.Retry.Run(r) |
|||
r.Handlers.AfterRetry.Run(r) |
|||
if r.Error != nil { |
|||
debugLogReqError(r, "Unmarshal Response", false, r.Error) |
|||
return r.Error |
|||
} |
|||
debugLogReqError(r, "Unmarshal Response", true, err) |
|||
continue |
|||
} |
|||
|
|||
break |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// AddToUserAgent adds the string to the end of the request's current user agent.
|
|||
func AddToUserAgent(r *Request, s string) { |
|||
curUA := r.HTTPRequest.Header.Get("User-Agent") |
|||
if len(curUA) > 0 { |
|||
s = curUA + " " + s |
|||
} |
|||
r.HTTPRequest.Header.Set("User-Agent", s) |
|||
} |
@ -0,0 +1,104 @@ |
|||
package request |
|||
|
|||
import ( |
|||
"reflect" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/awsutil" |
|||
) |
|||
|
|||
//type Paginater interface {
|
|||
// HasNextPage() bool
|
|||
// NextPage() *Request
|
|||
// EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error
|
|||
//}
|
|||
|
|||
// HasNextPage returns true if this request has more pages of data available.
|
|||
func (r *Request) HasNextPage() bool { |
|||
return len(r.nextPageTokens()) > 0 |
|||
} |
|||
|
|||
// nextPageTokens returns the tokens to use when asking for the next page of
|
|||
// data.
|
|||
func (r *Request) nextPageTokens() []interface{} { |
|||
if r.Operation.Paginator == nil { |
|||
return nil |
|||
} |
|||
|
|||
if r.Operation.TruncationToken != "" { |
|||
tr, _ := awsutil.ValuesAtPath(r.Data, r.Operation.TruncationToken) |
|||
if len(tr) == 0 { |
|||
return nil |
|||
} |
|||
|
|||
switch v := tr[0].(type) { |
|||
case *bool: |
|||
if !aws.BoolValue(v) { |
|||
return nil |
|||
} |
|||
case bool: |
|||
if v == false { |
|||
return nil |
|||
} |
|||
} |
|||
} |
|||
|
|||
tokens := []interface{}{} |
|||
tokenAdded := false |
|||
for _, outToken := range r.Operation.OutputTokens { |
|||
v, _ := awsutil.ValuesAtPath(r.Data, outToken) |
|||
if len(v) > 0 { |
|||
tokens = append(tokens, v[0]) |
|||
tokenAdded = true |
|||
} else { |
|||
tokens = append(tokens, nil) |
|||
} |
|||
} |
|||
if !tokenAdded { |
|||
return nil |
|||
} |
|||
|
|||
return tokens |
|||
} |
|||
|
|||
// NextPage returns a new Request that can be executed to return the next
|
|||
// page of result data. Call .Send() on this request to execute it.
|
|||
func (r *Request) NextPage() *Request { |
|||
tokens := r.nextPageTokens() |
|||
if len(tokens) == 0 { |
|||
return nil |
|||
} |
|||
|
|||
data := reflect.New(reflect.TypeOf(r.Data).Elem()).Interface() |
|||
nr := New(r.Config, r.ClientInfo, r.Handlers, r.Retryer, r.Operation, awsutil.CopyOf(r.Params), data) |
|||
for i, intok := range nr.Operation.InputTokens { |
|||
awsutil.SetValueAtPath(nr.Params, intok, tokens[i]) |
|||
} |
|||
return nr |
|||
} |
|||
|
|||
// EachPage iterates over each page of a paginated request object. The fn
|
|||
// parameter should be a function with the following sample signature:
|
|||
//
|
|||
// func(page *T, lastPage bool) bool {
|
|||
// return true // return false to stop iterating
|
|||
// }
|
|||
//
|
|||
// Where "T" is the structure type matching the output structure of the given
|
|||
// operation. For example, a request object generated by
|
|||
// DynamoDB.ListTablesRequest() would expect to see dynamodb.ListTablesOutput
|
|||
// as the structure "T". The lastPage value represents whether the page is
|
|||
// the last page of data or not. The return value of this function should
|
|||
// return true to keep iterating or false to stop.
|
|||
func (r *Request) EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error { |
|||
for page := r; page != nil; page = page.NextPage() { |
|||
if err := page.Send(); err != nil { |
|||
return err |
|||
} |
|||
if getNextPage := fn(page.Data, !page.HasNextPage()); !getNextPage { |
|||
return page.Error |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
@ -0,0 +1,101 @@ |
|||
package request |
|||
|
|||
import ( |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
) |
|||
|
|||
// Retryer is an interface to control retry logic for a given service.
|
|||
// The default implementation used by most services is the service.DefaultRetryer
|
|||
// structure, which contains basic retry logic using exponential backoff.
|
|||
type Retryer interface { |
|||
RetryRules(*Request) time.Duration |
|||
ShouldRetry(*Request) bool |
|||
MaxRetries() int |
|||
} |
|||
|
|||
// WithRetryer sets a config Retryer value to the given Config returning it
|
|||
// for chaining.
|
|||
func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config { |
|||
cfg.Retryer = retryer |
|||
return cfg |
|||
} |
|||
|
|||
// retryableCodes is a collection of service response codes which are retry-able
|
|||
// without any further action.
|
|||
var retryableCodes = map[string]struct{}{ |
|||
"RequestError": {}, |
|||
"RequestTimeout": {}, |
|||
} |
|||
|
|||
var throttleCodes = map[string]struct{}{ |
|||
"ProvisionedThroughputExceededException": {}, |
|||
"Throttling": {}, |
|||
"ThrottlingException": {}, |
|||
"RequestLimitExceeded": {}, |
|||
"RequestThrottled": {}, |
|||
"LimitExceededException": {}, // Deleting 10+ DynamoDb tables at once
|
|||
"TooManyRequestsException": {}, // Lambda functions
|
|||
} |
|||
|
|||
// credsExpiredCodes is a collection of error codes which signify the credentials
|
|||
// need to be refreshed. Expired tokens require refreshing of credentials, and
|
|||
// resigning before the request can be retried.
|
|||
var credsExpiredCodes = map[string]struct{}{ |
|||
"ExpiredToken": {}, |
|||
"ExpiredTokenException": {}, |
|||
"RequestExpired": {}, // EC2 Only
|
|||
} |
|||
|
|||
func isCodeThrottle(code string) bool { |
|||
_, ok := throttleCodes[code] |
|||
return ok |
|||
} |
|||
|
|||
func isCodeRetryable(code string) bool { |
|||
if _, ok := retryableCodes[code]; ok { |
|||
return true |
|||
} |
|||
|
|||
return isCodeExpiredCreds(code) |
|||
} |
|||
|
|||
func isCodeExpiredCreds(code string) bool { |
|||
_, ok := credsExpiredCodes[code] |
|||
return ok |
|||
} |
|||
|
|||
// IsErrorRetryable returns whether the error is retryable, based on its Code.
|
|||
// Returns false if the request has no Error set.
|
|||
func (r *Request) IsErrorRetryable() bool { |
|||
if r.Error != nil { |
|||
if err, ok := r.Error.(awserr.Error); ok { |
|||
return isCodeRetryable(err.Code()) |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// IsErrorThrottle returns whether the error is to be throttled based on its code.
|
|||
// Returns false if the request has no Error set
|
|||
func (r *Request) IsErrorThrottle() bool { |
|||
if r.Error != nil { |
|||
if err, ok := r.Error.(awserr.Error); ok { |
|||
return isCodeThrottle(err.Code()) |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// IsErrorExpired returns whether the error code is a credential expiry error.
|
|||
// Returns false if the request has no Error set.
|
|||
func (r *Request) IsErrorExpired() bool { |
|||
if r.Error != nil { |
|||
if err, ok := r.Error.(awserr.Error); ok { |
|||
return isCodeExpiredCreds(err.Code()) |
|||
} |
|||
} |
|||
return false |
|||
} |
@ -0,0 +1,234 @@ |
|||
package request |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
) |
|||
|
|||
const ( |
|||
// InvalidParameterErrCode is the error code for invalid parameters errors
|
|||
InvalidParameterErrCode = "InvalidParameter" |
|||
// ParamRequiredErrCode is the error code for required parameter errors
|
|||
ParamRequiredErrCode = "ParamRequiredError" |
|||
// ParamMinValueErrCode is the error code for fields with too low of a
|
|||
// number value.
|
|||
ParamMinValueErrCode = "ParamMinValueError" |
|||
// ParamMinLenErrCode is the error code for fields without enough elements.
|
|||
ParamMinLenErrCode = "ParamMinLenError" |
|||
) |
|||
|
|||
// Validator provides a way for types to perform validation logic on their
|
|||
// input values that external code can use to determine if a type's values
|
|||
// are valid.
|
|||
type Validator interface { |
|||
Validate() error |
|||
} |
|||
|
|||
// An ErrInvalidParams provides wrapping of invalid parameter errors found when
|
|||
// validating API operation input parameters.
|
|||
type ErrInvalidParams struct { |
|||
// Context is the base context of the invalid parameter group.
|
|||
Context string |
|||
errs []ErrInvalidParam |
|||
} |
|||
|
|||
// Add adds a new invalid parameter error to the collection of invalid
|
|||
// parameters. The context of the invalid parameter will be updated to reflect
|
|||
// this collection.
|
|||
func (e *ErrInvalidParams) Add(err ErrInvalidParam) { |
|||
err.SetContext(e.Context) |
|||
e.errs = append(e.errs, err) |
|||
} |
|||
|
|||
// AddNested adds the invalid parameter errors from another ErrInvalidParams
|
|||
// value into this collection. The nested errors will have their nested context
|
|||
// updated and base context to reflect the merging.
|
|||
//
|
|||
// Use for nested validations errors.
|
|||
func (e *ErrInvalidParams) AddNested(nestedCtx string, nested ErrInvalidParams) { |
|||
for _, err := range nested.errs { |
|||
err.SetContext(e.Context) |
|||
err.AddNestedContext(nestedCtx) |
|||
e.errs = append(e.errs, err) |
|||
} |
|||
} |
|||
|
|||
// Len returns the number of invalid parameter errors
|
|||
func (e ErrInvalidParams) Len() int { |
|||
return len(e.errs) |
|||
} |
|||
|
|||
// Code returns the code of the error
|
|||
func (e ErrInvalidParams) Code() string { |
|||
return InvalidParameterErrCode |
|||
} |
|||
|
|||
// Message returns the message of the error
|
|||
func (e ErrInvalidParams) Message() string { |
|||
return fmt.Sprintf("%d validation error(s) found.", len(e.errs)) |
|||
} |
|||
|
|||
// Error returns the string formatted form of the invalid parameters.
|
|||
func (e ErrInvalidParams) Error() string { |
|||
w := &bytes.Buffer{} |
|||
fmt.Fprintf(w, "%s: %s\n", e.Code(), e.Message()) |
|||
|
|||
for _, err := range e.errs { |
|||
fmt.Fprintf(w, "- %s\n", err.Message()) |
|||
} |
|||
|
|||
return w.String() |
|||
} |
|||
|
|||
// OrigErr returns the invalid parameters as a awserr.BatchedErrors value
|
|||
func (e ErrInvalidParams) OrigErr() error { |
|||
return awserr.NewBatchError( |
|||
InvalidParameterErrCode, e.Message(), e.OrigErrs()) |
|||
} |
|||
|
|||
// OrigErrs returns a slice of the invalid parameters
|
|||
func (e ErrInvalidParams) OrigErrs() []error { |
|||
errs := make([]error, len(e.errs)) |
|||
for i := 0; i < len(errs); i++ { |
|||
errs[i] = e.errs[i] |
|||
} |
|||
|
|||
return errs |
|||
} |
|||
|
|||
// An ErrInvalidParam represents an invalid parameter error type.
|
|||
type ErrInvalidParam interface { |
|||
awserr.Error |
|||
|
|||
// Field name the error occurred on.
|
|||
Field() string |
|||
|
|||
// SetContext updates the context of the error.
|
|||
SetContext(string) |
|||
|
|||
// AddNestedContext updates the error's context to include a nested level.
|
|||
AddNestedContext(string) |
|||
} |
|||
|
|||
type errInvalidParam struct { |
|||
context string |
|||
nestedContext string |
|||
field string |
|||
code string |
|||
msg string |
|||
} |
|||
|
|||
// Code returns the error code for the type of invalid parameter.
|
|||
func (e *errInvalidParam) Code() string { |
|||
return e.code |
|||
} |
|||
|
|||
// Message returns the reason the parameter was invalid, and its context.
|
|||
func (e *errInvalidParam) Message() string { |
|||
return fmt.Sprintf("%s, %s.", e.msg, e.Field()) |
|||
} |
|||
|
|||
// Error returns the string version of the invalid parameter error.
|
|||
func (e *errInvalidParam) Error() string { |
|||
return fmt.Sprintf("%s: %s", e.code, e.Message()) |
|||
} |
|||
|
|||
// OrigErr returns nil, Implemented for awserr.Error interface.
|
|||
func (e *errInvalidParam) OrigErr() error { |
|||
return nil |
|||
} |
|||
|
|||
// Field Returns the field and context the error occurred.
|
|||
func (e *errInvalidParam) Field() string { |
|||
field := e.context |
|||
if len(field) > 0 { |
|||
field += "." |
|||
} |
|||
if len(e.nestedContext) > 0 { |
|||
field += fmt.Sprintf("%s.", e.nestedContext) |
|||
} |
|||
field += e.field |
|||
|
|||
return field |
|||
} |
|||
|
|||
// SetContext updates the base context of the error.
|
|||
func (e *errInvalidParam) SetContext(ctx string) { |
|||
e.context = ctx |
|||
} |
|||
|
|||
// AddNestedContext prepends a context to the field's path.
|
|||
func (e *errInvalidParam) AddNestedContext(ctx string) { |
|||
if len(e.nestedContext) == 0 { |
|||
e.nestedContext = ctx |
|||
} else { |
|||
e.nestedContext = fmt.Sprintf("%s.%s", ctx, e.nestedContext) |
|||
} |
|||
|
|||
} |
|||
|
|||
// An ErrParamRequired represents an required parameter error.
|
|||
type ErrParamRequired struct { |
|||
errInvalidParam |
|||
} |
|||
|
|||
// NewErrParamRequired creates a new required parameter error.
|
|||
func NewErrParamRequired(field string) *ErrParamRequired { |
|||
return &ErrParamRequired{ |
|||
errInvalidParam{ |
|||
code: ParamRequiredErrCode, |
|||
field: field, |
|||
msg: fmt.Sprintf("missing required field"), |
|||
}, |
|||
} |
|||
} |
|||
|
|||
// An ErrParamMinValue represents a minimum value parameter error.
|
|||
type ErrParamMinValue struct { |
|||
errInvalidParam |
|||
min float64 |
|||
} |
|||
|
|||
// NewErrParamMinValue creates a new minimum value parameter error.
|
|||
func NewErrParamMinValue(field string, min float64) *ErrParamMinValue { |
|||
return &ErrParamMinValue{ |
|||
errInvalidParam: errInvalidParam{ |
|||
code: ParamMinValueErrCode, |
|||
field: field, |
|||
msg: fmt.Sprintf("minimum field value of %v", min), |
|||
}, |
|||
min: min, |
|||
} |
|||
} |
|||
|
|||
// MinValue returns the field's require minimum value.
|
|||
//
|
|||
// float64 is returned for both int and float min values.
|
|||
func (e *ErrParamMinValue) MinValue() float64 { |
|||
return e.min |
|||
} |
|||
|
|||
// An ErrParamMinLen represents a minimum length parameter error.
|
|||
type ErrParamMinLen struct { |
|||
errInvalidParam |
|||
min int |
|||
} |
|||
|
|||
// NewErrParamMinLen creates a new minimum length parameter error.
|
|||
func NewErrParamMinLen(field string, min int) *ErrParamMinLen { |
|||
return &ErrParamMinLen{ |
|||
errInvalidParam: errInvalidParam{ |
|||
code: ParamMinValueErrCode, |
|||
field: field, |
|||
msg: fmt.Sprintf("minimum field size of %v", min), |
|||
}, |
|||
min: min, |
|||
} |
|||
} |
|||
|
|||
// MinLen returns the field's required minimum length.
|
|||
func (e *ErrParamMinLen) MinLen() int { |
|||
return e.min |
|||
} |
@ -0,0 +1,223 @@ |
|||
/* |
|||
Package session provides configuration for the SDK's service clients. |
|||
|
|||
Sessions can be shared across all service clients that share the same base |
|||
configuration. The Session is built from the SDK's default configuration and |
|||
request handlers. |
|||
|
|||
Sessions should be cached when possible, because creating a new Session will |
|||
load all configuration values from the environment, and config files each time |
|||
the Session is created. Sharing the Session value across all of your service |
|||
clients will ensure the configuration is loaded the fewest number of times possible. |
|||
|
|||
Concurrency |
|||
|
|||
Sessions are safe to use concurrently as long as the Session is not being |
|||
modified. The SDK will not modify the Session once the Session has been created. |
|||
Creating service clients concurrently from a shared Session is safe. |
|||
|
|||
Sessions from Shared Config |
|||
|
|||
Sessions can be created using the method above that will only load the |
|||
additional config if the AWS_SDK_LOAD_CONFIG environment variable is set. |
|||
Alternatively you can explicitly create a Session with shared config enabled. |
|||
To do this you can use NewSessionWithOptions to configure how the Session will |
|||
be created. Using the NewSessionWithOptions with SharedConfigState set to |
|||
SharedConfigEnabled will create the session as if the AWS_SDK_LOAD_CONFIG |
|||
environment variable was set. |
|||
|
|||
Creating Sessions |
|||
|
|||
When creating Sessions optional aws.Config values can be passed in that will |
|||
override the default, or loaded config values the Session is being created |
|||
with. This allows you to provide additional, or case based, configuration |
|||
as needed. |
|||
|
|||
By default NewSession will only load credentials from the shared credentials |
|||
file (~/.aws/credentials). If the AWS_SDK_LOAD_CONFIG environment variable is |
|||
set to a truthy value the Session will be created from the configuration |
|||
values from the shared config (~/.aws/config) and shared credentials |
|||
(~/.aws/credentials) files. See the section Sessions from Shared Config for |
|||
more information. |
|||
|
|||
Create a Session with the default config and request handlers. With credentials |
|||
region, and profile loaded from the environment and shared config automatically. |
|||
Requires the AWS_PROFILE to be set, or "default" is used. |
|||
|
|||
// Create Session
|
|||
sess, err := session.NewSession() |
|||
|
|||
// Create a Session with a custom region
|
|||
sess, err := session.NewSession(&aws.Config{Region: aws.String("us-east-1")}) |
|||
|
|||
// Create a S3 client instance from a session
|
|||
sess, err := session.NewSession() |
|||
if err != nil { |
|||
// Handle Session creation error
|
|||
} |
|||
svc := s3.New(sess) |
|||
|
|||
Create Session With Option Overrides |
|||
|
|||
In addition to NewSession, Sessions can be created using NewSessionWithOptions. |
|||
This func allows you to control and override how the Session will be created |
|||
through code instead of being driven by environment variables only. |
|||
|
|||
Use NewSessionWithOptions when you want to provide the config profile, or |
|||
override the shared config state (AWS_SDK_LOAD_CONFIG). |
|||
|
|||
// Equivalent to session.New
|
|||
sess, err := session.NewSessionWithOptions(session.Options{}) |
|||
|
|||
// Specify profile to load for the session's config
|
|||
sess, err := session.NewSessionWithOptions(session.Options{ |
|||
Profile: "profile_name", |
|||
}) |
|||
|
|||
// Specify profile for config and region for requests
|
|||
sess, err := session.NewSessionWithOptions(session.Options{ |
|||
Config: aws.Config{Region: aws.String("us-east-1")}, |
|||
Profile: "profile_name", |
|||
}) |
|||
|
|||
// Force enable Shared Config support
|
|||
sess, err := session.NewSessionWithOptions(session.Options{ |
|||
SharedConfigState: SharedConfigEnable, |
|||
}) |
|||
|
|||
Adding Handlers |
|||
|
|||
You can add handlers to a session for processing HTTP requests. All service |
|||
clients that use the session inherit the handlers. For example, the following |
|||
handler logs every request and its payload made by a service client: |
|||
|
|||
// Create a session, and add additional handlers for all service
|
|||
// clients created with the Session to inherit. Adds logging handler.
|
|||
sess, err := session.NewSession() |
|||
sess.Handlers.Send.PushFront(func(r *request.Request) { |
|||
// Log every request made and its payload
|
|||
logger.Println("Request: %s/%s, Payload: %s", |
|||
r.ClientInfo.ServiceName, r.Operation, r.Params) |
|||
}) |
|||
|
|||
Deprecated "New" function |
|||
|
|||
The New session function has been deprecated because it does not provide good |
|||
way to return errors that occur when loading the configuration files and values. |
|||
Because of this, NewSession was created so errors can be retrieved when |
|||
creating a session fails. |
|||
|
|||
Shared Config Fields |
|||
|
|||
By default the SDK will only load the shared credentials file's (~/.aws/credentials) |
|||
credentials values, and all other config is provided by the environment variables, |
|||
SDK defaults, and user provided aws.Config values. |
|||
|
|||
If the AWS_SDK_LOAD_CONFIG environment variable is set, or SharedConfigEnable |
|||
option is used to create the Session the full shared config values will be |
|||
loaded. This includes credentials, region, and support for assume role. In |
|||
addition the Session will load its configuration from both the shared config |
|||
file (~/.aws/config) and shared credentials file (~/.aws/credentials). Both |
|||
files have the same format. |
|||
|
|||
If both config files are present the configuration from both files will be |
|||
read. The Session will be created from configuration values from the shared |
|||
credentials file (~/.aws/credentials) over those in the shared credentials |
|||
file (~/.aws/config). |
|||
|
|||
Credentials are the values the SDK should use for authenticating requests with |
|||
AWS Services. They arfrom a configuration file will need to include both |
|||
aws_access_key_id and aws_secret_access_key must be provided together in the |
|||
same file to be considered valid. The values will be ignored if not a complete |
|||
group. aws_session_token is an optional field that can be provided if both of |
|||
the other two fields are also provided. |
|||
|
|||
aws_access_key_id = AKID |
|||
aws_secret_access_key = SECRET |
|||
aws_session_token = TOKEN |
|||
|
|||
Assume Role values allow you to configure the SDK to assume an IAM role using |
|||
a set of credentials provided in a config file via the source_profile field. |
|||
Both "role_arn" and "source_profile" are required. The SDK does not support |
|||
assuming a role with MFA token Via the Session's constructor. You can use the |
|||
stscreds.AssumeRoleProvider credentials provider to specify custom |
|||
configuration and support for MFA. |
|||
|
|||
role_arn = arn:aws:iam::<account_number>:role/<role_name> |
|||
source_profile = profile_with_creds |
|||
external_id = 1234 |
|||
mfa_serial = not supported! |
|||
role_session_name = session_name |
|||
|
|||
Region is the region the SDK should use for looking up AWS service endpoints |
|||
and signing requests. |
|||
|
|||
region = us-east-1 |
|||
|
|||
Environment Variables |
|||
|
|||
When a Session is created several environment variables can be set to adjust |
|||
how the SDK functions, and what configuration data it loads when creating |
|||
Sessions. All environment values are optional, but some values like credentials |
|||
require multiple of the values to set or the partial values will be ignored. |
|||
All environment variable values are strings unless otherwise noted. |
|||
|
|||
Environment configuration values. If set both Access Key ID and Secret Access |
|||
Key must be provided. Session Token and optionally also be provided, but is |
|||
not required. |
|||
|
|||
# Access Key ID |
|||
AWS_ACCESS_KEY_ID=AKID |
|||
AWS_ACCESS_KEY=AKID # only read if AWS_ACCESS_KEY_ID is not set. |
|||
|
|||
# Secret Access Key |
|||
AWS_SECRET_ACCESS_KEY=SECRET |
|||
AWS_SECRET_KEY=SECRET=SECRET # only read if AWS_SECRET_ACCESS_KEY is not set. |
|||
|
|||
# Session Token |
|||
AWS_SESSION_TOKEN=TOKEN |
|||
|
|||
Region value will instruct the SDK where to make service API requests to. If is |
|||
not provided in the environment the region must be provided before a service |
|||
client request is made. |
|||
|
|||
AWS_REGION=us-east-1 |
|||
|
|||
# AWS_DEFAULT_REGION is only read if AWS_SDK_LOAD_CONFIG is also set, |
|||
# and AWS_REGION is not also set. |
|||
AWS_DEFAULT_REGION=us-east-1 |
|||
|
|||
Profile name the SDK should load use when loading shared config from the |
|||
configuration files. If not provided "default" will be used as the profile name. |
|||
|
|||
AWS_PROFILE=my_profile |
|||
|
|||
# AWS_DEFAULT_PROFILE is only read if AWS_SDK_LOAD_CONFIG is also set, |
|||
# and AWS_PROFILE is not also set. |
|||
AWS_DEFAULT_PROFILE=my_profile |
|||
|
|||
SDK load config instructs the SDK to load the shared config in addition to |
|||
shared credentials. This also expands the configuration loaded so the shared |
|||
credentials will have parity with the shared config file. This also enables |
|||
Region and Profile support for the AWS_DEFAULT_REGION and AWS_DEFAULT_PROFILE |
|||
env values as well. |
|||
|
|||
AWS_SDK_LOAD_CONFIG=1 |
|||
|
|||
Shared credentials file path can be set to instruct the SDK to use an alternative |
|||
file for the shared credentials. If not set the file will be loaded from |
|||
$HOME/.aws/credentials on Linux/Unix based systems, and |
|||
%USERPROFILE%\.aws\credentials on Windows. |
|||
|
|||
AWS_SHARED_CREDENTIALS_FILE=$HOME/my_shared_credentials |
|||
|
|||
Shared config file path can be set to instruct the SDK to use an alternative |
|||
file for the shared config. If not set the file will be loaded from |
|||
$HOME/.aws/config on Linux/Unix based systems, and |
|||
%USERPROFILE%\.aws\config on Windows. |
|||
|
|||
AWS_CONFIG_FILE=$HOME/my_shared_config |
|||
|
|||
|
|||
*/ |
|||
package session |
@ -0,0 +1,188 @@ |
|||
package session |
|||
|
|||
import ( |
|||
"os" |
|||
"path/filepath" |
|||
"strconv" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
) |
|||
|
|||
// envConfig is a collection of environment values the SDK will read
|
|||
// setup config from. All environment values are optional. But some values
|
|||
// such as credentials require multiple values to be complete or the values
|
|||
// will be ignored.
|
|||
type envConfig struct { |
|||
// Environment configuration values. If set both Access Key ID and Secret Access
|
|||
// Key must be provided. Session Token and optionally also be provided, but is
|
|||
// not required.
|
|||
//
|
|||
// # Access Key ID
|
|||
// AWS_ACCESS_KEY_ID=AKID
|
|||
// AWS_ACCESS_KEY=AKID # only read if AWS_ACCESS_KEY_ID is not set.
|
|||
//
|
|||
// # Secret Access Key
|
|||
// AWS_SECRET_ACCESS_KEY=SECRET
|
|||
// AWS_SECRET_KEY=SECRET=SECRET # only read if AWS_SECRET_ACCESS_KEY is not set.
|
|||
//
|
|||
// # Session Token
|
|||
// AWS_SESSION_TOKEN=TOKEN
|
|||
Creds credentials.Value |
|||
|
|||
// Region value will instruct the SDK where to make service API requests to. If is
|
|||
// not provided in the environment the region must be provided before a service
|
|||
// client request is made.
|
|||
//
|
|||
// AWS_REGION=us-east-1
|
|||
//
|
|||
// # AWS_DEFAULT_REGION is only read if AWS_SDK_LOAD_CONFIG is also set,
|
|||
// # and AWS_REGION is not also set.
|
|||
// AWS_DEFAULT_REGION=us-east-1
|
|||
Region string |
|||
|
|||
// Profile name the SDK should load use when loading shared configuration from the
|
|||
// shared configuration files. If not provided "default" will be used as the
|
|||
// profile name.
|
|||
//
|
|||
// AWS_PROFILE=my_profile
|
|||
//
|
|||
// # AWS_DEFAULT_PROFILE is only read if AWS_SDK_LOAD_CONFIG is also set,
|
|||
// # and AWS_PROFILE is not also set.
|
|||
// AWS_DEFAULT_PROFILE=my_profile
|
|||
Profile string |
|||
|
|||
// SDK load config instructs the SDK to load the shared config in addition to
|
|||
// shared credentials. This also expands the configuration loaded from the shared
|
|||
// credentials to have parity with the shared config file. This also enables
|
|||
// Region and Profile support for the AWS_DEFAULT_REGION and AWS_DEFAULT_PROFILE
|
|||
// env values as well.
|
|||
//
|
|||
// AWS_SDK_LOAD_CONFIG=1
|
|||
EnableSharedConfig bool |
|||
|
|||
// Shared credentials file path can be set to instruct the SDK to use an alternate
|
|||
// file for the shared credentials. If not set the file will be loaded from
|
|||
// $HOME/.aws/credentials on Linux/Unix based systems, and
|
|||
// %USERPROFILE%\.aws\credentials on Windows.
|
|||
//
|
|||
// AWS_SHARED_CREDENTIALS_FILE=$HOME/my_shared_credentials
|
|||
SharedCredentialsFile string |
|||
|
|||
// Shared config file path can be set to instruct the SDK to use an alternate
|
|||
// file for the shared config. If not set the file will be loaded from
|
|||
// $HOME/.aws/config on Linux/Unix based systems, and
|
|||
// %USERPROFILE%\.aws\config on Windows.
|
|||
//
|
|||
// AWS_CONFIG_FILE=$HOME/my_shared_config
|
|||
SharedConfigFile string |
|||
} |
|||
|
|||
var ( |
|||
credAccessEnvKey = []string{ |
|||
"AWS_ACCESS_KEY_ID", |
|||
"AWS_ACCESS_KEY", |
|||
} |
|||
credSecretEnvKey = []string{ |
|||
"AWS_SECRET_ACCESS_KEY", |
|||
"AWS_SECRET_KEY", |
|||
} |
|||
credSessionEnvKey = []string{ |
|||
"AWS_SESSION_TOKEN", |
|||
} |
|||
|
|||
regionEnvKeys = []string{ |
|||
"AWS_REGION", |
|||
"AWS_DEFAULT_REGION", // Only read if AWS_SDK_LOAD_CONFIG is also set
|
|||
} |
|||
profileEnvKeys = []string{ |
|||
"AWS_PROFILE", |
|||
"AWS_DEFAULT_PROFILE", // Only read if AWS_SDK_LOAD_CONFIG is also set
|
|||
} |
|||
) |
|||
|
|||
// loadEnvConfig retrieves the SDK's environment configuration.
|
|||
// See `envConfig` for the values that will be retrieved.
|
|||
//
|
|||
// If the environment variable `AWS_SDK_LOAD_CONFIG` is set to a truthy value
|
|||
// the shared SDK config will be loaded in addition to the SDK's specific
|
|||
// configuration values.
|
|||
func loadEnvConfig() envConfig { |
|||
enableSharedConfig, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG")) |
|||
return envConfigLoad(enableSharedConfig) |
|||
} |
|||
|
|||
// loadEnvSharedConfig retrieves the SDK's environment configuration, and the
|
|||
// SDK shared config. See `envConfig` for the values that will be retrieved.
|
|||
//
|
|||
// Loads the shared configuration in addition to the SDK's specific configuration.
|
|||
// This will load the same values as `loadEnvConfig` if the `AWS_SDK_LOAD_CONFIG`
|
|||
// environment variable is set.
|
|||
func loadSharedEnvConfig() envConfig { |
|||
return envConfigLoad(true) |
|||
} |
|||
|
|||
func envConfigLoad(enableSharedConfig bool) envConfig { |
|||
cfg := envConfig{} |
|||
|
|||
cfg.EnableSharedConfig = enableSharedConfig |
|||
|
|||
setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey) |
|||
setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey) |
|||
setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey) |
|||
|
|||
// Require logical grouping of credentials
|
|||
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 { |
|||
cfg.Creds = credentials.Value{} |
|||
} else { |
|||
cfg.Creds.ProviderName = "EnvConfigCredentials" |
|||
} |
|||
|
|||
regionKeys := regionEnvKeys |
|||
profileKeys := profileEnvKeys |
|||
if !cfg.EnableSharedConfig { |
|||
regionKeys = regionKeys[:1] |
|||
profileKeys = profileKeys[:1] |
|||
} |
|||
|
|||
setFromEnvVal(&cfg.Region, regionKeys) |
|||
setFromEnvVal(&cfg.Profile, profileKeys) |
|||
|
|||
cfg.SharedCredentialsFile = sharedCredentialsFilename() |
|||
cfg.SharedConfigFile = sharedConfigFilename() |
|||
|
|||
return cfg |
|||
} |
|||
|
|||
func setFromEnvVal(dst *string, keys []string) { |
|||
for _, k := range keys { |
|||
if v := os.Getenv(k); len(v) > 0 { |
|||
*dst = v |
|||
break |
|||
} |
|||
} |
|||
} |
|||
|
|||
func sharedCredentialsFilename() string { |
|||
if name := os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); len(name) > 0 { |
|||
return name |
|||
} |
|||
|
|||
return filepath.Join(userHomeDir(), ".aws", "credentials") |
|||
} |
|||
|
|||
func sharedConfigFilename() string { |
|||
if name := os.Getenv("AWS_CONFIG_FILE"); len(name) > 0 { |
|||
return name |
|||
} |
|||
|
|||
return filepath.Join(userHomeDir(), ".aws", "config") |
|||
} |
|||
|
|||
func userHomeDir() string { |
|||
homeDir := os.Getenv("HOME") // *nix
|
|||
if len(homeDir) == 0 { // windows
|
|||
homeDir = os.Getenv("USERPROFILE") |
|||
} |
|||
|
|||
return homeDir |
|||
} |
@ -0,0 +1,393 @@ |
|||
package session |
|||
|
|||
import ( |
|||
"fmt" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/client" |
|||
"github.com/aws/aws-sdk-go/aws/corehandlers" |
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds" |
|||
"github.com/aws/aws-sdk-go/aws/defaults" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
"github.com/aws/aws-sdk-go/private/endpoints" |
|||
) |
|||
|
|||
// A Session provides a central location to create service clients from and
|
|||
// store configurations and request handlers for those services.
|
|||
//
|
|||
// Sessions are safe to create service clients concurrently, but it is not safe
|
|||
// to mutate the Session concurrently.
|
|||
//
|
|||
// The Session satisfies the service client's client.ClientConfigProvider.
|
|||
type Session struct { |
|||
Config *aws.Config |
|||
Handlers request.Handlers |
|||
} |
|||
|
|||
// New creates a new instance of the handlers merging in the provided configs
|
|||
// on top of the SDK's default configurations. Once the Session is created it
|
|||
// can be mutated to modify the Config or Handlers. The Session is safe to be
|
|||
// read concurrently, but it should not be written to concurrently.
|
|||
//
|
|||
// If the AWS_SDK_LOAD_CONFIG environment is set to a truthy value, the New
|
|||
// method could now encounter an error when loading the configuration. When
|
|||
// The environment variable is set, and an error occurs, New will return a
|
|||
// session that will fail all requests reporting the error that occured while
|
|||
// loading the session. Use NewSession to get the error when creating the
|
|||
// session.
|
|||
//
|
|||
// If the AWS_SDK_LOAD_CONFIG environment variable is set to a truthy value
|
|||
// the shared config file (~/.aws/config) will also be loaded, in addition to
|
|||
// the shared credentials file (~/.aws/config). Values set in both the
|
|||
// shared config, and shared credentials will be taken from the shared
|
|||
// credentials file.
|
|||
//
|
|||
// Deprecated: Use NewSession functiions to create sessions instead. NewSession
|
|||
// has the same functionality as New except an error can be returned when the
|
|||
// func is called instead of waiting to receive an error until a request is made.
|
|||
func New(cfgs ...*aws.Config) *Session { |
|||
// load initial config from environment
|
|||
envCfg := loadEnvConfig() |
|||
|
|||
if envCfg.EnableSharedConfig { |
|||
s, err := newSession(envCfg, cfgs...) |
|||
if err != nil { |
|||
// Old session.New expected all errors to be discovered when
|
|||
// a request is made, and would report the errors then. This
|
|||
// needs to be replicated if an error occurs while creating
|
|||
// the session.
|
|||
msg := "failed to create session with AWS_SDK_LOAD_CONFIG enabled. " + |
|||
"Use session.NewSession to handle errors occuring during session creation." |
|||
|
|||
// Session creation failed, need to report the error and prevent
|
|||
// any requests from succeeding.
|
|||
s = &Session{Config: defaults.Config()} |
|||
s.Config.MergeIn(cfgs...) |
|||
s.Config.Logger.Log("ERROR:", msg, "Error:", err) |
|||
s.Handlers.Validate.PushBack(func(r *request.Request) { |
|||
r.Error = err |
|||
}) |
|||
} |
|||
return s |
|||
} |
|||
|
|||
return oldNewSession(cfgs...) |
|||
} |
|||
|
|||
// NewSession returns a new Session created from SDK defaults, config files,
|
|||
// environment, and user provided config files. Once the Session is created
|
|||
// it can be mutated to modify the Config or Handlers. The Session is safe to
|
|||
// be read concurrently, but it should not be written to concurrently.
|
|||
//
|
|||
// If the AWS_SDK_LOAD_CONFIG environment variable is set to a truthy value
|
|||
// the shared config file (~/.aws/config) will also be loaded in addition to
|
|||
// the shared credentials file (~/.aws/config). Values set in both the
|
|||
// shared config, and shared credentials will be taken from the shared
|
|||
// credentials file. Enabling the Shared Config will also allow the Session
|
|||
// to be built with retrieving credentials with AssumeRole set in the config.
|
|||
//
|
|||
// See the NewSessionWithOptions func for information on how to override or
|
|||
// control through code how the Session will be created. Such as specifing the
|
|||
// config profile, and controlling if shared config is enabled or not.
|
|||
func NewSession(cfgs ...*aws.Config) (*Session, error) { |
|||
envCfg := loadEnvConfig() |
|||
|
|||
return newSession(envCfg, cfgs...) |
|||
} |
|||
|
|||
// SharedConfigState provides the ability to optionally override the state
|
|||
// of the session's creation based on the shared config being enabled or
|
|||
// disabled.
|
|||
type SharedConfigState int |
|||
|
|||
const ( |
|||
// SharedConfigStateFromEnv does not override any state of the
|
|||
// AWS_SDK_LOAD_CONFIG env var. It is the default value of the
|
|||
// SharedConfigState type.
|
|||
SharedConfigStateFromEnv SharedConfigState = iota |
|||
|
|||
// SharedConfigDisable overrides the AWS_SDK_LOAD_CONFIG env var value
|
|||
// and disables the shared config functionality.
|
|||
SharedConfigDisable |
|||
|
|||
// SharedConfigEnable overrides the AWS_SDK_LOAD_CONFIG env var value
|
|||
// and enables the shared config functionality.
|
|||
SharedConfigEnable |
|||
) |
|||
|
|||
// Options provides the means to control how a Session is created and what
|
|||
// configuration values will be loaded.
|
|||
//
|
|||
type Options struct { |
|||
// Provides config values for the SDK to use when creating service clients
|
|||
// and making API requests to services. Any value set in with this field
|
|||
// will override the associated value provided by the SDK defaults,
|
|||
// environment or config files where relevent.
|
|||
//
|
|||
// If not set, configuration values from from SDK defaults, environment,
|
|||
// config will be used.
|
|||
Config aws.Config |
|||
|
|||
// Overrides the config profile the Session should be created from. If not
|
|||
// set the value of the environment variable will be loaded (AWS_PROFILE,
|
|||
// or AWS_DEFAULT_PROFILE if the Shared Config is enabled).
|
|||
//
|
|||
// If not set and environment variables are not set the "default"
|
|||
// (DefaultSharedConfigProfile) will be used as the profile to load the
|
|||
// session config from.
|
|||
Profile string |
|||
|
|||
// Instructs how the Session will be created based on the AWS_SDK_LOAD_CONFIG
|
|||
// environment variable. By default a Session will be created using the
|
|||
// value provided by the AWS_SDK_LOAD_CONFIG environment variable.
|
|||
//
|
|||
// Setting this value to SharedConfigEnable or SharedConfigDisable
|
|||
// will allow you to override the AWS_SDK_LOAD_CONFIG environment variable
|
|||
// and enable or disable the shared config functionality.
|
|||
SharedConfigState SharedConfigState |
|||
} |
|||
|
|||
// NewSessionWithOptions returns a new Session created from SDK defaults, config files,
|
|||
// environment, and user provided config files. This func uses the Options
|
|||
// values to configure how the Session is created.
|
|||
//
|
|||
// If the AWS_SDK_LOAD_CONFIG environment variable is set to a truthy value
|
|||
// the shared config file (~/.aws/config) will also be loaded in addition to
|
|||
// the shared credentials file (~/.aws/config). Values set in both the
|
|||
// shared config, and shared credentials will be taken from the shared
|
|||
// credentials file. Enabling the Shared Config will also allow the Session
|
|||
// to be built with retrieving credentials with AssumeRole set in the config.
|
|||
//
|
|||
// // Equivalent to session.New
|
|||
// sess, err := session.NewSessionWithOptions(session.Options{})
|
|||
//
|
|||
// // Specify profile to load for the session's config
|
|||
// sess, err := session.NewSessionWithOptions(session.Options{
|
|||
// Profile: "profile_name",
|
|||
// })
|
|||
//
|
|||
// // Specify profile for config and region for requests
|
|||
// sess, err := session.NewSessionWithOptions(session.Options{
|
|||
// Config: aws.Config{Region: aws.String("us-east-1")},
|
|||
// Profile: "profile_name",
|
|||
// })
|
|||
//
|
|||
// // Force enable Shared Config support
|
|||
// sess, err := session.NewSessionWithOptions(session.Options{
|
|||
// SharedConfigState: SharedConfigEnable,
|
|||
// })
|
|||
func NewSessionWithOptions(opts Options) (*Session, error) { |
|||
var envCfg envConfig |
|||
if opts.SharedConfigState == SharedConfigEnable { |
|||
envCfg = loadSharedEnvConfig() |
|||
} else { |
|||
envCfg = loadEnvConfig() |
|||
} |
|||
|
|||
if len(opts.Profile) > 0 { |
|||
envCfg.Profile = opts.Profile |
|||
} |
|||
|
|||
switch opts.SharedConfigState { |
|||
case SharedConfigDisable: |
|||
envCfg.EnableSharedConfig = false |
|||
case SharedConfigEnable: |
|||
envCfg.EnableSharedConfig = true |
|||
} |
|||
|
|||
return newSession(envCfg, &opts.Config) |
|||
} |
|||
|
|||
// Must is a helper function to ensure the Session is valid and there was no
|
|||
// error when calling a NewSession function.
|
|||
//
|
|||
// This helper is intended to be used in variable initialization to load the
|
|||
// Session and configuration at startup. Such as:
|
|||
//
|
|||
// var sess = session.Must(session.NewSession())
|
|||
func Must(sess *Session, err error) *Session { |
|||
if err != nil { |
|||
panic(err) |
|||
} |
|||
|
|||
return sess |
|||
} |
|||
|
|||
func oldNewSession(cfgs ...*aws.Config) *Session { |
|||
cfg := defaults.Config() |
|||
handlers := defaults.Handlers() |
|||
|
|||
// Apply the passed in configs so the configuration can be applied to the
|
|||
// default credential chain
|
|||
cfg.MergeIn(cfgs...) |
|||
cfg.Credentials = defaults.CredChain(cfg, handlers) |
|||
|
|||
// Reapply any passed in configs to override credentials if set
|
|||
cfg.MergeIn(cfgs...) |
|||
|
|||
s := &Session{ |
|||
Config: cfg, |
|||
Handlers: handlers, |
|||
} |
|||
|
|||
initHandlers(s) |
|||
|
|||
return s |
|||
} |
|||
|
|||
func newSession(envCfg envConfig, cfgs ...*aws.Config) (*Session, error) { |
|||
cfg := defaults.Config() |
|||
handlers := defaults.Handlers() |
|||
|
|||
// Get a merged version of the user provided config to determine if
|
|||
// credentials were.
|
|||
userCfg := &aws.Config{} |
|||
userCfg.MergeIn(cfgs...) |
|||
|
|||
// Order config files will be loaded in with later files overwriting
|
|||
// previous config file values.
|
|||
cfgFiles := []string{envCfg.SharedConfigFile, envCfg.SharedCredentialsFile} |
|||
if !envCfg.EnableSharedConfig { |
|||
// The shared config file (~/.aws/config) is only loaded if instructed
|
|||
// to load via the envConfig.EnableSharedConfig (AWS_SDK_LOAD_CONFIG).
|
|||
cfgFiles = cfgFiles[1:] |
|||
} |
|||
|
|||
// Load additional config from file(s)
|
|||
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers) |
|||
|
|||
s := &Session{ |
|||
Config: cfg, |
|||
Handlers: handlers, |
|||
} |
|||
|
|||
initHandlers(s) |
|||
|
|||
return s, nil |
|||
} |
|||
|
|||
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers) { |
|||
// Merge in user provided configuration
|
|||
cfg.MergeIn(userCfg) |
|||
|
|||
// Region if not already set by user
|
|||
if len(aws.StringValue(cfg.Region)) == 0 { |
|||
if len(envCfg.Region) > 0 { |
|||
cfg.WithRegion(envCfg.Region) |
|||
} else if envCfg.EnableSharedConfig && len(sharedCfg.Region) > 0 { |
|||
cfg.WithRegion(sharedCfg.Region) |
|||
} |
|||
} |
|||
|
|||
// Configure credentials if not already set
|
|||
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil { |
|||
if len(envCfg.Creds.AccessKeyID) > 0 { |
|||
cfg.Credentials = credentials.NewStaticCredentialsFromCreds( |
|||
envCfg.Creds, |
|||
) |
|||
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil { |
|||
cfgCp := *cfg |
|||
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds( |
|||
sharedCfg.AssumeRoleSource.Creds, |
|||
) |
|||
cfg.Credentials = stscreds.NewCredentials( |
|||
&Session{ |
|||
Config: &cfgCp, |
|||
Handlers: handlers.Copy(), |
|||
}, |
|||
sharedCfg.AssumeRole.RoleARN, |
|||
func(opt *stscreds.AssumeRoleProvider) { |
|||
opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName |
|||
|
|||
if len(sharedCfg.AssumeRole.ExternalID) > 0 { |
|||
opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID) |
|||
} |
|||
|
|||
// MFA not supported
|
|||
}, |
|||
) |
|||
} else if len(sharedCfg.Creds.AccessKeyID) > 0 { |
|||
cfg.Credentials = credentials.NewStaticCredentialsFromCreds( |
|||
sharedCfg.Creds, |
|||
) |
|||
} else { |
|||
// Fallback to default credentials provider, include mock errors
|
|||
// for the credential chain so user can identify why credentials
|
|||
// failed to be retrieved.
|
|||
cfg.Credentials = credentials.NewCredentials(&credentials.ChainProvider{ |
|||
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors), |
|||
Providers: []credentials.Provider{ |
|||
&credProviderError{Err: awserr.New("EnvAccessKeyNotFound", "failed to find credentials in the environment.", nil)}, |
|||
&credProviderError{Err: awserr.New("SharedCredsLoad", fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil)}, |
|||
defaults.RemoteCredProvider(*cfg, handlers), |
|||
}, |
|||
}) |
|||
} |
|||
} |
|||
} |
|||
|
|||
type credProviderError struct { |
|||
Err error |
|||
} |
|||
|
|||
var emptyCreds = credentials.Value{} |
|||
|
|||
func (c credProviderError) Retrieve() (credentials.Value, error) { |
|||
return credentials.Value{}, c.Err |
|||
} |
|||
func (c credProviderError) IsExpired() bool { |
|||
return true |
|||
} |
|||
|
|||
func initHandlers(s *Session) { |
|||
// Add the Validate parameter handler if it is not disabled.
|
|||
s.Handlers.Validate.Remove(corehandlers.ValidateParametersHandler) |
|||
if !aws.BoolValue(s.Config.DisableParamValidation) { |
|||
s.Handlers.Validate.PushBackNamed(corehandlers.ValidateParametersHandler) |
|||
} |
|||
} |
|||
|
|||
// Copy creates and returns a copy of the current Session, coping the config
|
|||
// and handlers. If any additional configs are provided they will be merged
|
|||
// on top of the Session's copied config.
|
|||
//
|
|||
// // Create a copy of the current Session, configured for the us-west-2 region.
|
|||
// sess.Copy(&aws.Config{Region: aws.String("us-west-2")})
|
|||
func (s *Session) Copy(cfgs ...*aws.Config) *Session { |
|||
newSession := &Session{ |
|||
Config: s.Config.Copy(cfgs...), |
|||
Handlers: s.Handlers.Copy(), |
|||
} |
|||
|
|||
initHandlers(newSession) |
|||
|
|||
return newSession |
|||
} |
|||
|
|||
// ClientConfig satisfies the client.ConfigProvider interface and is used to
|
|||
// configure the service client instances. Passing the Session to the service
|
|||
// client's constructor (New) will use this method to configure the client.
|
|||
func (s *Session) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config { |
|||
s = s.Copy(cfgs...) |
|||
endpoint, signingRegion := endpoints.NormalizeEndpoint( |
|||
aws.StringValue(s.Config.Endpoint), |
|||
serviceName, |
|||
aws.StringValue(s.Config.Region), |
|||
aws.BoolValue(s.Config.DisableSSL), |
|||
aws.BoolValue(s.Config.UseDualStack), |
|||
) |
|||
|
|||
return client.Config{ |
|||
Config: s.Config, |
|||
Handlers: s.Handlers, |
|||
Endpoint: endpoint, |
|||
SigningRegion: signingRegion, |
|||
} |
|||
} |
@ -0,0 +1,294 @@ |
|||
package session |
|||
|
|||
import ( |
|||
"fmt" |
|||
"os" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
"github.com/go-ini/ini" |
|||
) |
|||
|
|||
const ( |
|||
// Static Credentials group
|
|||
accessKeyIDKey = `aws_access_key_id` // group required
|
|||
secretAccessKey = `aws_secret_access_key` // group required
|
|||
sessionTokenKey = `aws_session_token` // optional
|
|||
|
|||
// Assume Role Credentials group
|
|||
roleArnKey = `role_arn` // group required
|
|||
sourceProfileKey = `source_profile` // group required
|
|||
externalIDKey = `external_id` // optional
|
|||
mfaSerialKey = `mfa_serial` // optional
|
|||
roleSessionNameKey = `role_session_name` // optional
|
|||
|
|||
// Additional Config fields
|
|||
regionKey = `region` |
|||
|
|||
// DefaultSharedConfigProfile is the default profile to be used when
|
|||
// loading configuration from the config files if another profile name
|
|||
// is not provided.
|
|||
DefaultSharedConfigProfile = `default` |
|||
) |
|||
|
|||
type assumeRoleConfig struct { |
|||
RoleARN string |
|||
SourceProfile string |
|||
ExternalID string |
|||
MFASerial string |
|||
RoleSessionName string |
|||
} |
|||
|
|||
// sharedConfig represents the configuration fields of the SDK config files.
|
|||
type sharedConfig struct { |
|||
// Credentials values from the config file. Both aws_access_key_id
|
|||
// and aws_secret_access_key must be provided together in the same file
|
|||
// to be considered valid. The values will be ignored if not a complete group.
|
|||
// aws_session_token is an optional field that can be provided if both of the
|
|||
// other two fields are also provided.
|
|||
//
|
|||
// aws_access_key_id
|
|||
// aws_secret_access_key
|
|||
// aws_session_token
|
|||
Creds credentials.Value |
|||
|
|||
AssumeRole assumeRoleConfig |
|||
AssumeRoleSource *sharedConfig |
|||
|
|||
// Region is the region the SDK should use for looking up AWS service endpoints
|
|||
// and signing requests.
|
|||
//
|
|||
// region
|
|||
Region string |
|||
} |
|||
|
|||
type sharedConfigFile struct { |
|||
Filename string |
|||
IniData *ini.File |
|||
} |
|||
|
|||
// loadSharedConfig retrieves the configuration from the list of files
|
|||
// using the profile provided. The order the files are listed will determine
|
|||
// precedence. Values in subsequent files will overwrite values defined in
|
|||
// earlier files.
|
|||
//
|
|||
// For example, given two files A and B. Both define credentials. If the order
|
|||
// of the files are A then B, B's credential values will be used instead of A's.
|
|||
//
|
|||
// See sharedConfig.setFromFile for information how the config files
|
|||
// will be loaded.
|
|||
func loadSharedConfig(profile string, filenames []string) (sharedConfig, error) { |
|||
if len(profile) == 0 { |
|||
profile = DefaultSharedConfigProfile |
|||
} |
|||
|
|||
files, err := loadSharedConfigIniFiles(filenames) |
|||
if err != nil { |
|||
return sharedConfig{}, err |
|||
} |
|||
|
|||
cfg := sharedConfig{} |
|||
if err = cfg.setFromIniFiles(profile, files); err != nil { |
|||
return sharedConfig{}, err |
|||
} |
|||
|
|||
if len(cfg.AssumeRole.SourceProfile) > 0 { |
|||
if err := cfg.setAssumeRoleSource(profile, files); err != nil { |
|||
return sharedConfig{}, err |
|||
} |
|||
} |
|||
|
|||
return cfg, nil |
|||
} |
|||
|
|||
func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) { |
|||
files := make([]sharedConfigFile, 0, len(filenames)) |
|||
|
|||
for _, filename := range filenames { |
|||
if _, err := os.Stat(filename); os.IsNotExist(err) { |
|||
// Trim files from the list that don't exist.
|
|||
continue |
|||
} |
|||
|
|||
f, err := ini.Load(filename) |
|||
if err != nil { |
|||
return nil, SharedConfigLoadError{Filename: filename} |
|||
} |
|||
|
|||
files = append(files, sharedConfigFile{ |
|||
Filename: filename, IniData: f, |
|||
}) |
|||
} |
|||
|
|||
return files, nil |
|||
} |
|||
|
|||
func (cfg *sharedConfig) setAssumeRoleSource(origProfile string, files []sharedConfigFile) error { |
|||
var assumeRoleSrc sharedConfig |
|||
|
|||
// Multiple level assume role chains are not support
|
|||
if cfg.AssumeRole.SourceProfile == origProfile { |
|||
assumeRoleSrc = *cfg |
|||
assumeRoleSrc.AssumeRole = assumeRoleConfig{} |
|||
} else { |
|||
err := assumeRoleSrc.setFromIniFiles(cfg.AssumeRole.SourceProfile, files) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
if len(assumeRoleSrc.Creds.AccessKeyID) == 0 { |
|||
return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN} |
|||
} |
|||
|
|||
cfg.AssumeRoleSource = &assumeRoleSrc |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (cfg *sharedConfig) setFromIniFiles(profile string, files []sharedConfigFile) error { |
|||
// Trim files from the list that don't exist.
|
|||
for _, f := range files { |
|||
if err := cfg.setFromIniFile(profile, f); err != nil { |
|||
if _, ok := err.(SharedConfigProfileNotExistsError); ok { |
|||
// Ignore proviles missings
|
|||
continue |
|||
} |
|||
return err |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// setFromFile loads the configuration from the file using
|
|||
// the profile provided. A sharedConfig pointer type value is used so that
|
|||
// multiple config file loadings can be chained.
|
|||
//
|
|||
// Only loads complete logically grouped values, and will not set fields in cfg
|
|||
// for incomplete grouped values in the config. Such as credentials. For example
|
|||
// if a config file only includes aws_access_key_id but no aws_secret_access_key
|
|||
// the aws_access_key_id will be ignored.
|
|||
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) error { |
|||
section, err := file.IniData.GetSection(profile) |
|||
if err != nil { |
|||
// Fallback to to alternate profile name: profile <name>
|
|||
section, err = file.IniData.GetSection(fmt.Sprintf("profile %s", profile)) |
|||
if err != nil { |
|||
return SharedConfigProfileNotExistsError{Profile: profile, Err: err} |
|||
} |
|||
} |
|||
|
|||
// Shared Credentials
|
|||
akid := section.Key(accessKeyIDKey).String() |
|||
secret := section.Key(secretAccessKey).String() |
|||
if len(akid) > 0 && len(secret) > 0 { |
|||
cfg.Creds = credentials.Value{ |
|||
AccessKeyID: akid, |
|||
SecretAccessKey: secret, |
|||
SessionToken: section.Key(sessionTokenKey).String(), |
|||
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename), |
|||
} |
|||
} |
|||
|
|||
// Assume Role
|
|||
roleArn := section.Key(roleArnKey).String() |
|||
srcProfile := section.Key(sourceProfileKey).String() |
|||
if len(roleArn) > 0 && len(srcProfile) > 0 { |
|||
cfg.AssumeRole = assumeRoleConfig{ |
|||
RoleARN: roleArn, |
|||
SourceProfile: srcProfile, |
|||
ExternalID: section.Key(externalIDKey).String(), |
|||
MFASerial: section.Key(mfaSerialKey).String(), |
|||
RoleSessionName: section.Key(roleSessionNameKey).String(), |
|||
} |
|||
} |
|||
|
|||
// Region
|
|||
if v := section.Key(regionKey).String(); len(v) > 0 { |
|||
cfg.Region = v |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// SharedConfigLoadError is an error for the shared config file failed to load.
|
|||
type SharedConfigLoadError struct { |
|||
Filename string |
|||
Err error |
|||
} |
|||
|
|||
// Code is the short id of the error.
|
|||
func (e SharedConfigLoadError) Code() string { |
|||
return "SharedConfigLoadError" |
|||
} |
|||
|
|||
// Message is the description of the error
|
|||
func (e SharedConfigLoadError) Message() string { |
|||
return fmt.Sprintf("failed to load config file, %s", e.Filename) |
|||
} |
|||
|
|||
// OrigErr is the underlying error that caused the failure.
|
|||
func (e SharedConfigLoadError) OrigErr() error { |
|||
return e.Err |
|||
} |
|||
|
|||
// Error satisfies the error interface.
|
|||
func (e SharedConfigLoadError) Error() string { |
|||
return awserr.SprintError(e.Code(), e.Message(), "", e.Err) |
|||
} |
|||
|
|||
// SharedConfigProfileNotExistsError is an error for the shared config when
|
|||
// the profile was not find in the config file.
|
|||
type SharedConfigProfileNotExistsError struct { |
|||
Profile string |
|||
Err error |
|||
} |
|||
|
|||
// Code is the short id of the error.
|
|||
func (e SharedConfigProfileNotExistsError) Code() string { |
|||
return "SharedConfigProfileNotExistsError" |
|||
} |
|||
|
|||
// Message is the description of the error
|
|||
func (e SharedConfigProfileNotExistsError) Message() string { |
|||
return fmt.Sprintf("failed to get profile, %s", e.Profile) |
|||
} |
|||
|
|||
// OrigErr is the underlying error that caused the failure.
|
|||
func (e SharedConfigProfileNotExistsError) OrigErr() error { |
|||
return e.Err |
|||
} |
|||
|
|||
// Error satisfies the error interface.
|
|||
func (e SharedConfigProfileNotExistsError) Error() string { |
|||
return awserr.SprintError(e.Code(), e.Message(), "", e.Err) |
|||
} |
|||
|
|||
// SharedConfigAssumeRoleError is an error for the shared config when the
|
|||
// profile contains assume role information, but that information is invalid
|
|||
// or not complete.
|
|||
type SharedConfigAssumeRoleError struct { |
|||
RoleARN string |
|||
} |
|||
|
|||
// Code is the short id of the error.
|
|||
func (e SharedConfigAssumeRoleError) Code() string { |
|||
return "SharedConfigAssumeRoleError" |
|||
} |
|||
|
|||
// Message is the description of the error
|
|||
func (e SharedConfigAssumeRoleError) Message() string { |
|||
return fmt.Sprintf("failed to load assume role for %s, source profile has no shared credentials", |
|||
e.RoleARN) |
|||
} |
|||
|
|||
// OrigErr is the underlying error that caused the failure.
|
|||
func (e SharedConfigAssumeRoleError) OrigErr() error { |
|||
return nil |
|||
} |
|||
|
|||
// Error satisfies the error interface.
|
|||
func (e SharedConfigAssumeRoleError) Error() string { |
|||
return awserr.SprintError(e.Code(), e.Message(), "", nil) |
|||
} |
@ -0,0 +1,82 @@ |
|||
package v4 |
|||
|
|||
import ( |
|||
"net/http" |
|||
"strings" |
|||
) |
|||
|
|||
// validator houses a set of rule needed for validation of a
|
|||
// string value
|
|||
type rules []rule |
|||
|
|||
// rule interface allows for more flexible rules and just simply
|
|||
// checks whether or not a value adheres to that rule
|
|||
type rule interface { |
|||
IsValid(value string) bool |
|||
} |
|||
|
|||
// IsValid will iterate through all rules and see if any rules
|
|||
// apply to the value and supports nested rules
|
|||
func (r rules) IsValid(value string) bool { |
|||
for _, rule := range r { |
|||
if rule.IsValid(value) { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// mapRule generic rule for maps
|
|||
type mapRule map[string]struct{} |
|||
|
|||
// IsValid for the map rule satisfies whether it exists in the map
|
|||
func (m mapRule) IsValid(value string) bool { |
|||
_, ok := m[value] |
|||
return ok |
|||
} |
|||
|
|||
// whitelist is a generic rule for whitelisting
|
|||
type whitelist struct { |
|||
rule |
|||
} |
|||
|
|||
// IsValid for whitelist checks if the value is within the whitelist
|
|||
func (w whitelist) IsValid(value string) bool { |
|||
return w.rule.IsValid(value) |
|||
} |
|||
|
|||
// blacklist is a generic rule for blacklisting
|
|||
type blacklist struct { |
|||
rule |
|||
} |
|||
|
|||
// IsValid for whitelist checks if the value is within the whitelist
|
|||
func (b blacklist) IsValid(value string) bool { |
|||
return !b.rule.IsValid(value) |
|||
} |
|||
|
|||
type patterns []string |
|||
|
|||
// IsValid for patterns checks each pattern and returns if a match has
|
|||
// been found
|
|||
func (p patterns) IsValid(value string) bool { |
|||
for _, pattern := range p { |
|||
if strings.HasPrefix(http.CanonicalHeaderKey(value), pattern) { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// inclusiveRules rules allow for rules to depend on one another
|
|||
type inclusiveRules []rule |
|||
|
|||
// IsValid will return true if all rules are true
|
|||
func (r inclusiveRules) IsValid(value string) bool { |
|||
for _, rule := range r { |
|||
if !rule.IsValid(value) { |
|||
return false |
|||
} |
|||
} |
|||
return true |
|||
} |
@ -0,0 +1,662 @@ |
|||
// Package v4 implements signing for AWS V4 signer
|
|||
//
|
|||
// Provides request signing for request that need to be signed with
|
|||
// AWS V4 Signatures.
|
|||
package v4 |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto/hmac" |
|||
"crypto/sha256" |
|||
"encoding/hex" |
|||
"fmt" |
|||
"io" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/url" |
|||
"sort" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/credentials" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
"github.com/aws/aws-sdk-go/private/protocol/rest" |
|||
) |
|||
|
|||
const ( |
|||
authHeaderPrefix = "AWS4-HMAC-SHA256" |
|||
timeFormat = "20060102T150405Z" |
|||
shortTimeFormat = "20060102" |
|||
|
|||
// emptyStringSHA256 is a SHA256 of an empty string
|
|||
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855` |
|||
) |
|||
|
|||
var ignoredHeaders = rules{ |
|||
blacklist{ |
|||
mapRule{ |
|||
"Authorization": struct{}{}, |
|||
"User-Agent": struct{}{}, |
|||
}, |
|||
}, |
|||
} |
|||
|
|||
// requiredSignedHeaders is a whitelist for build canonical headers.
|
|||
var requiredSignedHeaders = rules{ |
|||
whitelist{ |
|||
mapRule{ |
|||
"Cache-Control": struct{}{}, |
|||
"Content-Disposition": struct{}{}, |
|||
"Content-Encoding": struct{}{}, |
|||
"Content-Language": struct{}{}, |
|||
"Content-Md5": struct{}{}, |
|||
"Content-Type": struct{}{}, |
|||
"Expires": struct{}{}, |
|||
"If-Match": struct{}{}, |
|||
"If-Modified-Since": struct{}{}, |
|||
"If-None-Match": struct{}{}, |
|||
"If-Unmodified-Since": struct{}{}, |
|||
"Range": struct{}{}, |
|||
"X-Amz-Acl": struct{}{}, |
|||
"X-Amz-Copy-Source": struct{}{}, |
|||
"X-Amz-Copy-Source-If-Match": struct{}{}, |
|||
"X-Amz-Copy-Source-If-Modified-Since": struct{}{}, |
|||
"X-Amz-Copy-Source-If-None-Match": struct{}{}, |
|||
"X-Amz-Copy-Source-If-Unmodified-Since": struct{}{}, |
|||
"X-Amz-Copy-Source-Range": struct{}{}, |
|||
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": struct{}{}, |
|||
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": struct{}{}, |
|||
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": struct{}{}, |
|||
"X-Amz-Grant-Full-control": struct{}{}, |
|||
"X-Amz-Grant-Read": struct{}{}, |
|||
"X-Amz-Grant-Read-Acp": struct{}{}, |
|||
"X-Amz-Grant-Write": struct{}{}, |
|||
"X-Amz-Grant-Write-Acp": struct{}{}, |
|||
"X-Amz-Metadata-Directive": struct{}{}, |
|||
"X-Amz-Mfa": struct{}{}, |
|||
"X-Amz-Request-Payer": struct{}{}, |
|||
"X-Amz-Server-Side-Encryption": struct{}{}, |
|||
"X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": struct{}{}, |
|||
"X-Amz-Server-Side-Encryption-Customer-Algorithm": struct{}{}, |
|||
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{}, |
|||
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{}, |
|||
"X-Amz-Storage-Class": struct{}{}, |
|||
"X-Amz-Website-Redirect-Location": struct{}{}, |
|||
}, |
|||
}, |
|||
patterns{"X-Amz-Meta-"}, |
|||
} |
|||
|
|||
// allowedHoisting is a whitelist for build query headers. The boolean value
|
|||
// represents whether or not it is a pattern.
|
|||
var allowedQueryHoisting = inclusiveRules{ |
|||
blacklist{requiredSignedHeaders}, |
|||
patterns{"X-Amz-"}, |
|||
} |
|||
|
|||
// Signer applies AWS v4 signing to given request. Use this to sign requests
|
|||
// that need to be signed with AWS V4 Signatures.
|
|||
type Signer struct { |
|||
// The authentication credentials the request will be signed against.
|
|||
// This value must be set to sign requests.
|
|||
Credentials *credentials.Credentials |
|||
|
|||
// Sets the log level the signer should use when reporting information to
|
|||
// the logger. If the logger is nil nothing will be logged. See
|
|||
// aws.LogLevelType for more information on available logging levels
|
|||
//
|
|||
// By default nothing will be logged.
|
|||
Debug aws.LogLevelType |
|||
|
|||
// The logger loging information will be written to. If there the logger
|
|||
// is nil, nothing will be logged.
|
|||
Logger aws.Logger |
|||
|
|||
// Disables the Signer's moving HTTP header key/value pairs from the HTTP
|
|||
// request header to the request's query string. This is most commonly used
|
|||
// with pre-signed requests preventing headers from being added to the
|
|||
// request's query string.
|
|||
DisableHeaderHoisting bool |
|||
|
|||
// currentTimeFn returns the time value which represents the current time.
|
|||
// This value should only be used for testing. If it is nil the default
|
|||
// time.Now will be used.
|
|||
currentTimeFn func() time.Time |
|||
} |
|||
|
|||
// NewSigner returns a Signer pointer configured with the credentials and optional
|
|||
// option values provided. If not options are provided the Signer will use its
|
|||
// default configuration.
|
|||
func NewSigner(credentials *credentials.Credentials, options ...func(*Signer)) *Signer { |
|||
v4 := &Signer{ |
|||
Credentials: credentials, |
|||
} |
|||
|
|||
for _, option := range options { |
|||
option(v4) |
|||
} |
|||
|
|||
return v4 |
|||
} |
|||
|
|||
type signingCtx struct { |
|||
ServiceName string |
|||
Region string |
|||
Request *http.Request |
|||
Body io.ReadSeeker |
|||
Query url.Values |
|||
Time time.Time |
|||
ExpireTime time.Duration |
|||
SignedHeaderVals http.Header |
|||
|
|||
credValues credentials.Value |
|||
isPresign bool |
|||
formattedTime string |
|||
formattedShortTime string |
|||
|
|||
bodyDigest string |
|||
signedHeaders string |
|||
canonicalHeaders string |
|||
canonicalString string |
|||
credentialString string |
|||
stringToSign string |
|||
signature string |
|||
authorization string |
|||
} |
|||
|
|||
// Sign signs AWS v4 requests with the provided body, service name, region the
|
|||
// request is made to, and time the request is signed at. The signTime allows
|
|||
// you to specify that a request is signed for the future, and cannot be
|
|||
// used until then.
|
|||
//
|
|||
// Returns a list of HTTP headers that were included in the signature or an
|
|||
// error if signing the request failed. Generally for signed requests this value
|
|||
// is not needed as the full request context will be captured by the http.Request
|
|||
// value. It is included for reference though.
|
|||
//
|
|||
// Sign will set the request's Body to be the `body` parameter passed in. If
|
|||
// the body is not already an io.ReadCloser, it will be wrapped within one. If
|
|||
// a `nil` body parameter passed to Sign, the request's Body field will be
|
|||
// also set to nil. Its important to note that this functionality will not
|
|||
// change the request's ContentLength of the request.
|
|||
//
|
|||
// Sign differs from Presign in that it will sign the request using HTTP
|
|||
// header values. This type of signing is intended for http.Request values that
|
|||
// will not be shared, or are shared in a way the header values on the request
|
|||
// will not be lost.
|
|||
//
|
|||
// The requests body is an io.ReadSeeker so the SHA256 of the body can be
|
|||
// generated. To bypass the signer computing the hash you can set the
|
|||
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
|
|||
// only compute the hash if the request header value is empty.
|
|||
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) { |
|||
return v4.signWithBody(r, body, service, region, 0, signTime) |
|||
} |
|||
|
|||
// Presign signs AWS v4 requests with the provided body, service name, region
|
|||
// the request is made to, and time the request is signed at. The signTime
|
|||
// allows you to specify that a request is signed for the future, and cannot
|
|||
// be used until then.
|
|||
//
|
|||
// Returns a list of HTTP headers that were included in the signature or an
|
|||
// error if signing the request failed. For presigned requests these headers
|
|||
// and their values must be included on the HTTP request when it is made. This
|
|||
// is helpful to know what header values need to be shared with the party the
|
|||
// presigned request will be distributed to.
|
|||
//
|
|||
// Presign differs from Sign in that it will sign the request using query string
|
|||
// instead of header values. This allows you to share the Presigned Request's
|
|||
// URL with third parties, or distribute it throughout your system with minimal
|
|||
// dependencies.
|
|||
//
|
|||
// Presign also takes an exp value which is the duration the
|
|||
// signed request will be valid after the signing time. This is allows you to
|
|||
// set when the request will expire.
|
|||
//
|
|||
// The requests body is an io.ReadSeeker so the SHA256 of the body can be
|
|||
// generated. To bypass the signer computing the hash you can set the
|
|||
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
|
|||
// only compute the hash if the request header value is empty.
|
|||
//
|
|||
// Presigning a S3 request will not compute the body's SHA256 hash by default.
|
|||
// This is done due to the general use case for S3 presigned URLs is to share
|
|||
// PUT/GET capabilities. If you would like to include the body's SHA256 in the
|
|||
// presigned request's signature you can set the "X-Amz-Content-Sha256"
|
|||
// HTTP header and that will be included in the request's signature.
|
|||
func (v4 Signer) Presign(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) { |
|||
return v4.signWithBody(r, body, service, region, exp, signTime) |
|||
} |
|||
|
|||
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) { |
|||
currentTimeFn := v4.currentTimeFn |
|||
if currentTimeFn == nil { |
|||
currentTimeFn = time.Now |
|||
} |
|||
|
|||
ctx := &signingCtx{ |
|||
Request: r, |
|||
Body: body, |
|||
Query: r.URL.Query(), |
|||
Time: signTime, |
|||
ExpireTime: exp, |
|||
isPresign: exp != 0, |
|||
ServiceName: service, |
|||
Region: region, |
|||
} |
|||
|
|||
if ctx.isRequestSigned() { |
|||
ctx.Time = currentTimeFn() |
|||
ctx.handlePresignRemoval() |
|||
} |
|||
|
|||
var err error |
|||
ctx.credValues, err = v4.Credentials.Get() |
|||
if err != nil { |
|||
return http.Header{}, err |
|||
} |
|||
|
|||
ctx.assignAmzQueryValues() |
|||
ctx.build(v4.DisableHeaderHoisting) |
|||
|
|||
// If the request is not presigned the body should be attached to it. This
|
|||
// prevents the confusion of wanting to send a signed request without
|
|||
// the body the request was signed for attached.
|
|||
if !ctx.isPresign { |
|||
var reader io.ReadCloser |
|||
if body != nil { |
|||
var ok bool |
|||
if reader, ok = body.(io.ReadCloser); !ok { |
|||
reader = ioutil.NopCloser(body) |
|||
} |
|||
} |
|||
r.Body = reader |
|||
} |
|||
|
|||
if v4.Debug.Matches(aws.LogDebugWithSigning) { |
|||
v4.logSigningInfo(ctx) |
|||
} |
|||
|
|||
return ctx.SignedHeaderVals, nil |
|||
} |
|||
|
|||
func (ctx *signingCtx) handlePresignRemoval() { |
|||
if !ctx.isPresign { |
|||
return |
|||
} |
|||
|
|||
// The credentials have expired for this request. The current signing
|
|||
// is invalid, and needs to be request because the request will fail.
|
|||
ctx.removePresign() |
|||
|
|||
// Update the request's query string to ensure the values stays in
|
|||
// sync in the case retrieving the new credentials fails.
|
|||
ctx.Request.URL.RawQuery = ctx.Query.Encode() |
|||
} |
|||
|
|||
func (ctx *signingCtx) assignAmzQueryValues() { |
|||
if ctx.isPresign { |
|||
ctx.Query.Set("X-Amz-Algorithm", authHeaderPrefix) |
|||
if ctx.credValues.SessionToken != "" { |
|||
ctx.Query.Set("X-Amz-Security-Token", ctx.credValues.SessionToken) |
|||
} else { |
|||
ctx.Query.Del("X-Amz-Security-Token") |
|||
} |
|||
|
|||
return |
|||
} |
|||
|
|||
if ctx.credValues.SessionToken != "" { |
|||
ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken) |
|||
} |
|||
} |
|||
|
|||
// SignRequestHandler is a named request handler the SDK will use to sign
|
|||
// service client request with using the V4 signature.
|
|||
var SignRequestHandler = request.NamedHandler{ |
|||
Name: "v4.SignRequestHandler", Fn: SignSDKRequest, |
|||
} |
|||
|
|||
// SignSDKRequest signs an AWS request with the V4 signature. This
|
|||
// request handler is bested used only with the SDK's built in service client's
|
|||
// API operation requests.
|
|||
//
|
|||
// This function should not be used on its on its own, but in conjunction with
|
|||
// an AWS service client's API operation call. To sign a standalone request
|
|||
// not created by a service client's API operation method use the "Sign" or
|
|||
// "Presign" functions of the "Signer" type.
|
|||
//
|
|||
// If the credentials of the request's config are set to
|
|||
// credentials.AnonymousCredentials the request will not be signed.
|
|||
func SignSDKRequest(req *request.Request) { |
|||
signSDKRequestWithCurrTime(req, time.Now) |
|||
} |
|||
func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time) { |
|||
// If the request does not need to be signed ignore the signing of the
|
|||
// request if the AnonymousCredentials object is used.
|
|||
if req.Config.Credentials == credentials.AnonymousCredentials { |
|||
return |
|||
} |
|||
|
|||
region := req.ClientInfo.SigningRegion |
|||
if region == "" { |
|||
region = aws.StringValue(req.Config.Region) |
|||
} |
|||
|
|||
name := req.ClientInfo.SigningName |
|||
if name == "" { |
|||
name = req.ClientInfo.ServiceName |
|||
} |
|||
|
|||
v4 := NewSigner(req.Config.Credentials, func(v4 *Signer) { |
|||
v4.Debug = req.Config.LogLevel.Value() |
|||
v4.Logger = req.Config.Logger |
|||
v4.DisableHeaderHoisting = req.NotHoist |
|||
v4.currentTimeFn = curTimeFn |
|||
}) |
|||
|
|||
signingTime := req.Time |
|||
if !req.LastSignedAt.IsZero() { |
|||
signingTime = req.LastSignedAt |
|||
} |
|||
|
|||
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(), |
|||
name, region, req.ExpireTime, signingTime, |
|||
) |
|||
if err != nil { |
|||
req.Error = err |
|||
req.SignedHeaderVals = nil |
|||
return |
|||
} |
|||
|
|||
req.SignedHeaderVals = signedHeaders |
|||
req.LastSignedAt = curTimeFn() |
|||
} |
|||
|
|||
const logSignInfoMsg = `DEBUG: Request Signature: |
|||
---[ CANONICAL STRING ]----------------------------- |
|||
%s |
|||
---[ STRING TO SIGN ]-------------------------------- |
|||
%s%s |
|||
-----------------------------------------------------` |
|||
const logSignedURLMsg = ` |
|||
---[ SIGNED URL ]------------------------------------ |
|||
%s` |
|||
|
|||
func (v4 *Signer) logSigningInfo(ctx *signingCtx) { |
|||
signedURLMsg := "" |
|||
if ctx.isPresign { |
|||
signedURLMsg = fmt.Sprintf(logSignedURLMsg, ctx.Request.URL.String()) |
|||
} |
|||
msg := fmt.Sprintf(logSignInfoMsg, ctx.canonicalString, ctx.stringToSign, signedURLMsg) |
|||
v4.Logger.Log(msg) |
|||
} |
|||
|
|||
func (ctx *signingCtx) build(disableHeaderHoisting bool) { |
|||
ctx.buildTime() // no depends
|
|||
ctx.buildCredentialString() // no depends
|
|||
|
|||
unsignedHeaders := ctx.Request.Header |
|||
if ctx.isPresign { |
|||
if !disableHeaderHoisting { |
|||
urlValues := url.Values{} |
|||
urlValues, unsignedHeaders = buildQuery(allowedQueryHoisting, unsignedHeaders) // no depends
|
|||
for k := range urlValues { |
|||
ctx.Query[k] = urlValues[k] |
|||
} |
|||
} |
|||
} |
|||
|
|||
ctx.buildBodyDigest() |
|||
ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders) |
|||
ctx.buildCanonicalString() // depends on canon headers / signed headers
|
|||
ctx.buildStringToSign() // depends on canon string
|
|||
ctx.buildSignature() // depends on string to sign
|
|||
|
|||
if ctx.isPresign { |
|||
ctx.Request.URL.RawQuery += "&X-Amz-Signature=" + ctx.signature |
|||
} else { |
|||
parts := []string{ |
|||
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString, |
|||
"SignedHeaders=" + ctx.signedHeaders, |
|||
"Signature=" + ctx.signature, |
|||
} |
|||
ctx.Request.Header.Set("Authorization", strings.Join(parts, ", ")) |
|||
} |
|||
} |
|||
|
|||
func (ctx *signingCtx) buildTime() { |
|||
ctx.formattedTime = ctx.Time.UTC().Format(timeFormat) |
|||
ctx.formattedShortTime = ctx.Time.UTC().Format(shortTimeFormat) |
|||
|
|||
if ctx.isPresign { |
|||
duration := int64(ctx.ExpireTime / time.Second) |
|||
ctx.Query.Set("X-Amz-Date", ctx.formattedTime) |
|||
ctx.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10)) |
|||
} else { |
|||
ctx.Request.Header.Set("X-Amz-Date", ctx.formattedTime) |
|||
} |
|||
} |
|||
|
|||
func (ctx *signingCtx) buildCredentialString() { |
|||
ctx.credentialString = strings.Join([]string{ |
|||
ctx.formattedShortTime, |
|||
ctx.Region, |
|||
ctx.ServiceName, |
|||
"aws4_request", |
|||
}, "/") |
|||
|
|||
if ctx.isPresign { |
|||
ctx.Query.Set("X-Amz-Credential", ctx.credValues.AccessKeyID+"/"+ctx.credentialString) |
|||
} |
|||
} |
|||
|
|||
func buildQuery(r rule, header http.Header) (url.Values, http.Header) { |
|||
query := url.Values{} |
|||
unsignedHeaders := http.Header{} |
|||
for k, h := range header { |
|||
if r.IsValid(k) { |
|||
query[k] = h |
|||
} else { |
|||
unsignedHeaders[k] = h |
|||
} |
|||
} |
|||
|
|||
return query, unsignedHeaders |
|||
} |
|||
func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) { |
|||
var headers []string |
|||
headers = append(headers, "host") |
|||
for k, v := range header { |
|||
canonicalKey := http.CanonicalHeaderKey(k) |
|||
if !r.IsValid(canonicalKey) { |
|||
continue // ignored header
|
|||
} |
|||
if ctx.SignedHeaderVals == nil { |
|||
ctx.SignedHeaderVals = make(http.Header) |
|||
} |
|||
|
|||
lowerCaseKey := strings.ToLower(k) |
|||
if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok { |
|||
// include additional values
|
|||
ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...) |
|||
continue |
|||
} |
|||
|
|||
headers = append(headers, lowerCaseKey) |
|||
ctx.SignedHeaderVals[lowerCaseKey] = v |
|||
} |
|||
sort.Strings(headers) |
|||
|
|||
ctx.signedHeaders = strings.Join(headers, ";") |
|||
|
|||
if ctx.isPresign { |
|||
ctx.Query.Set("X-Amz-SignedHeaders", ctx.signedHeaders) |
|||
} |
|||
|
|||
headerValues := make([]string, len(headers)) |
|||
for i, k := range headers { |
|||
if k == "host" { |
|||
headerValues[i] = "host:" + ctx.Request.URL.Host |
|||
} else { |
|||
headerValues[i] = k + ":" + |
|||
strings.Join(ctx.SignedHeaderVals[k], ",") |
|||
} |
|||
} |
|||
|
|||
ctx.canonicalHeaders = strings.Join(stripExcessSpaces(headerValues), "\n") |
|||
} |
|||
|
|||
func (ctx *signingCtx) buildCanonicalString() { |
|||
ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1) |
|||
uri := ctx.Request.URL.Opaque |
|||
if uri != "" { |
|||
uri = "/" + strings.Join(strings.Split(uri, "/")[3:], "/") |
|||
} else { |
|||
uri = ctx.Request.URL.Path |
|||
} |
|||
if uri == "" { |
|||
uri = "/" |
|||
} |
|||
|
|||
if ctx.ServiceName != "s3" { |
|||
uri = rest.EscapePath(uri, false) |
|||
} |
|||
|
|||
ctx.canonicalString = strings.Join([]string{ |
|||
ctx.Request.Method, |
|||
uri, |
|||
ctx.Request.URL.RawQuery, |
|||
ctx.canonicalHeaders + "\n", |
|||
ctx.signedHeaders, |
|||
ctx.bodyDigest, |
|||
}, "\n") |
|||
} |
|||
|
|||
func (ctx *signingCtx) buildStringToSign() { |
|||
ctx.stringToSign = strings.Join([]string{ |
|||
authHeaderPrefix, |
|||
ctx.formattedTime, |
|||
ctx.credentialString, |
|||
hex.EncodeToString(makeSha256([]byte(ctx.canonicalString))), |
|||
}, "\n") |
|||
} |
|||
|
|||
func (ctx *signingCtx) buildSignature() { |
|||
secret := ctx.credValues.SecretAccessKey |
|||
date := makeHmac([]byte("AWS4"+secret), []byte(ctx.formattedShortTime)) |
|||
region := makeHmac(date, []byte(ctx.Region)) |
|||
service := makeHmac(region, []byte(ctx.ServiceName)) |
|||
credentials := makeHmac(service, []byte("aws4_request")) |
|||
signature := makeHmac(credentials, []byte(ctx.stringToSign)) |
|||
ctx.signature = hex.EncodeToString(signature) |
|||
} |
|||
|
|||
func (ctx *signingCtx) buildBodyDigest() { |
|||
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256") |
|||
if hash == "" { |
|||
if ctx.isPresign && ctx.ServiceName == "s3" { |
|||
hash = "UNSIGNED-PAYLOAD" |
|||
} else if ctx.Body == nil { |
|||
hash = emptyStringSHA256 |
|||
} else { |
|||
hash = hex.EncodeToString(makeSha256Reader(ctx.Body)) |
|||
} |
|||
if ctx.ServiceName == "s3" || ctx.ServiceName == "glacier" { |
|||
ctx.Request.Header.Set("X-Amz-Content-Sha256", hash) |
|||
} |
|||
} |
|||
ctx.bodyDigest = hash |
|||
} |
|||
|
|||
// isRequestSigned returns if the request is currently signed or presigned
|
|||
func (ctx *signingCtx) isRequestSigned() bool { |
|||
if ctx.isPresign && ctx.Query.Get("X-Amz-Signature") != "" { |
|||
return true |
|||
} |
|||
if ctx.Request.Header.Get("Authorization") != "" { |
|||
return true |
|||
} |
|||
|
|||
return false |
|||
} |
|||
|
|||
// unsign removes signing flags for both signed and presigned requests.
|
|||
func (ctx *signingCtx) removePresign() { |
|||
ctx.Query.Del("X-Amz-Algorithm") |
|||
ctx.Query.Del("X-Amz-Signature") |
|||
ctx.Query.Del("X-Amz-Security-Token") |
|||
ctx.Query.Del("X-Amz-Date") |
|||
ctx.Query.Del("X-Amz-Expires") |
|||
ctx.Query.Del("X-Amz-Credential") |
|||
ctx.Query.Del("X-Amz-SignedHeaders") |
|||
} |
|||
|
|||
func makeHmac(key []byte, data []byte) []byte { |
|||
hash := hmac.New(sha256.New, key) |
|||
hash.Write(data) |
|||
return hash.Sum(nil) |
|||
} |
|||
|
|||
func makeSha256(data []byte) []byte { |
|||
hash := sha256.New() |
|||
hash.Write(data) |
|||
return hash.Sum(nil) |
|||
} |
|||
|
|||
func makeSha256Reader(reader io.ReadSeeker) []byte { |
|||
hash := sha256.New() |
|||
start, _ := reader.Seek(0, 1) |
|||
defer reader.Seek(start, 0) |
|||
|
|||
io.Copy(hash, reader) |
|||
return hash.Sum(nil) |
|||
} |
|||
|
|||
const doubleSpaces = " " |
|||
|
|||
var doubleSpaceBytes = []byte(doubleSpaces) |
|||
|
|||
func stripExcessSpaces(headerVals []string) []string { |
|||
vals := make([]string, len(headerVals)) |
|||
for i, str := range headerVals { |
|||
// Trim leading and trailing spaces
|
|||
trimmed := strings.TrimSpace(str) |
|||
|
|||
idx := strings.Index(trimmed, doubleSpaces) |
|||
var buf []byte |
|||
for idx > -1 { |
|||
// Multiple adjacent spaces found
|
|||
if buf == nil { |
|||
// first time create the buffer
|
|||
buf = []byte(trimmed) |
|||
} |
|||
|
|||
stripToIdx := -1 |
|||
for j := idx + 1; j < len(buf); j++ { |
|||
if buf[j] != ' ' { |
|||
buf = append(buf[:idx+1], buf[j:]...) |
|||
stripToIdx = j |
|||
break |
|||
} |
|||
} |
|||
|
|||
if stripToIdx >= 0 { |
|||
idx = bytes.Index(buf[stripToIdx:], doubleSpaceBytes) |
|||
if idx >= 0 { |
|||
idx += stripToIdx |
|||
} |
|||
} else { |
|||
idx = -1 |
|||
} |
|||
} |
|||
|
|||
if buf != nil { |
|||
vals[i] = string(buf) |
|||
} else { |
|||
vals[i] = trimmed |
|||
} |
|||
} |
|||
return vals |
|||
} |
@ -0,0 +1,106 @@ |
|||
package aws |
|||
|
|||
import ( |
|||
"io" |
|||
"sync" |
|||
) |
|||
|
|||
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser
|
|||
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser { |
|||
return ReaderSeekerCloser{r} |
|||
} |
|||
|
|||
// ReaderSeekerCloser represents a reader that can also delegate io.Seeker and
|
|||
// io.Closer interfaces to the underlying object if they are available.
|
|||
type ReaderSeekerCloser struct { |
|||
r io.Reader |
|||
} |
|||
|
|||
// Read reads from the reader up to size of p. The number of bytes read, and
|
|||
// error if it occurred will be returned.
|
|||
//
|
|||
// If the reader is not an io.Reader zero bytes read, and nil error will be returned.
|
|||
//
|
|||
// Performs the same functionality as io.Reader Read
|
|||
func (r ReaderSeekerCloser) Read(p []byte) (int, error) { |
|||
switch t := r.r.(type) { |
|||
case io.Reader: |
|||
return t.Read(p) |
|||
} |
|||
return 0, nil |
|||
} |
|||
|
|||
// Seek sets the offset for the next Read to offset, interpreted according to
|
|||
// whence: 0 means relative to the origin of the file, 1 means relative to the
|
|||
// current offset, and 2 means relative to the end. Seek returns the new offset
|
|||
// and an error, if any.
|
|||
//
|
|||
// If the ReaderSeekerCloser is not an io.Seeker nothing will be done.
|
|||
func (r ReaderSeekerCloser) Seek(offset int64, whence int) (int64, error) { |
|||
switch t := r.r.(type) { |
|||
case io.Seeker: |
|||
return t.Seek(offset, whence) |
|||
} |
|||
return int64(0), nil |
|||
} |
|||
|
|||
// Close closes the ReaderSeekerCloser.
|
|||
//
|
|||
// If the ReaderSeekerCloser is not an io.Closer nothing will be done.
|
|||
func (r ReaderSeekerCloser) Close() error { |
|||
switch t := r.r.(type) { |
|||
case io.Closer: |
|||
return t.Close() |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// A WriteAtBuffer provides a in memory buffer supporting the io.WriterAt interface
|
|||
// Can be used with the s3manager.Downloader to download content to a buffer
|
|||
// in memory. Safe to use concurrently.
|
|||
type WriteAtBuffer struct { |
|||
buf []byte |
|||
m sync.Mutex |
|||
|
|||
// GrowthCoeff defines the growth rate of the internal buffer. By
|
|||
// default, the growth rate is 1, where expanding the internal
|
|||
// buffer will allocate only enough capacity to fit the new expected
|
|||
// length.
|
|||
GrowthCoeff float64 |
|||
} |
|||
|
|||
// NewWriteAtBuffer creates a WriteAtBuffer with an internal buffer
|
|||
// provided by buf.
|
|||
func NewWriteAtBuffer(buf []byte) *WriteAtBuffer { |
|||
return &WriteAtBuffer{buf: buf} |
|||
} |
|||
|
|||
// WriteAt writes a slice of bytes to a buffer starting at the position provided
|
|||
// The number of bytes written will be returned, or error. Can overwrite previous
|
|||
// written slices if the write ats overlap.
|
|||
func (b *WriteAtBuffer) WriteAt(p []byte, pos int64) (n int, err error) { |
|||
pLen := len(p) |
|||
expLen := pos + int64(pLen) |
|||
b.m.Lock() |
|||
defer b.m.Unlock() |
|||
if int64(len(b.buf)) < expLen { |
|||
if int64(cap(b.buf)) < expLen { |
|||
if b.GrowthCoeff < 1 { |
|||
b.GrowthCoeff = 1 |
|||
} |
|||
newBuf := make([]byte, expLen, int64(b.GrowthCoeff*float64(expLen))) |
|||
copy(newBuf, b.buf) |
|||
b.buf = newBuf |
|||
} |
|||
b.buf = b.buf[:expLen] |
|||
} |
|||
copy(b.buf[pos:], p) |
|||
return pLen, nil |
|||
} |
|||
|
|||
// Bytes returns a slice of bytes written to the buffer.
|
|||
func (b *WriteAtBuffer) Bytes() []byte { |
|||
b.m.Lock() |
|||
defer b.m.Unlock() |
|||
return b.buf[:len(b.buf):len(b.buf)] |
|||
} |
@ -0,0 +1,8 @@ |
|||
// Package aws provides core functionality for making requests to AWS services.
|
|||
package aws |
|||
|
|||
// SDKName is the name of this AWS SDK
|
|||
const SDKName = "aws-sdk-go" |
|||
|
|||
// SDKVersion is the version of this SDK
|
|||
const SDKVersion = "1.4.14" |
@ -0,0 +1,70 @@ |
|||
// Package endpoints validates regional endpoints for services.
|
|||
package endpoints |
|||
|
|||
//go:generate go run ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
|
|||
//go:generate gofmt -s -w endpoints_map.go
|
|||
|
|||
import ( |
|||
"fmt" |
|||
"regexp" |
|||
"strings" |
|||
) |
|||
|
|||
// NormalizeEndpoint takes and endpoint and service API information to return a
|
|||
// normalized endpoint and signing region. If the endpoint is not an empty string
|
|||
// the service name and region will be used to look up the service's API endpoint.
|
|||
// If the endpoint is provided the scheme will be added if it is not present.
|
|||
func NormalizeEndpoint(endpoint, serviceName, region string, disableSSL, useDualStack bool) (normEndpoint, signingRegion string) { |
|||
if endpoint == "" { |
|||
return EndpointForRegion(serviceName, region, disableSSL, useDualStack) |
|||
} |
|||
|
|||
return AddScheme(endpoint, disableSSL), "" |
|||
} |
|||
|
|||
// EndpointForRegion returns an endpoint and its signing region for a service and region.
|
|||
// if the service and region pair are not found endpoint and signingRegion will be empty.
|
|||
func EndpointForRegion(svcName, region string, disableSSL, useDualStack bool) (endpoint, signingRegion string) { |
|||
dualStackField := "" |
|||
if useDualStack { |
|||
dualStackField = "/dualstack" |
|||
} |
|||
|
|||
derivedKeys := []string{ |
|||
region + "/" + svcName + dualStackField, |
|||
region + "/*" + dualStackField, |
|||
"*/" + svcName + dualStackField, |
|||
"*/*" + dualStackField, |
|||
} |
|||
|
|||
for _, key := range derivedKeys { |
|||
if val, ok := endpointsMap.Endpoints[key]; ok { |
|||
ep := val.Endpoint |
|||
ep = strings.Replace(ep, "{region}", region, -1) |
|||
ep = strings.Replace(ep, "{service}", svcName, -1) |
|||
|
|||
endpoint = ep |
|||
signingRegion = val.SigningRegion |
|||
break |
|||
} |
|||
} |
|||
|
|||
return AddScheme(endpoint, disableSSL), signingRegion |
|||
} |
|||
|
|||
// Regular expression to determine if the endpoint string is prefixed with a scheme.
|
|||
var schemeRE = regexp.MustCompile("^([^:]+)://") |
|||
|
|||
// AddScheme adds the HTTP or HTTPS schemes to a endpoint URL if there is no
|
|||
// scheme. If disableSSL is true HTTP will be added instead of the default HTTPS.
|
|||
func AddScheme(endpoint string, disableSSL bool) string { |
|||
if endpoint != "" && !schemeRE.MatchString(endpoint) { |
|||
scheme := "https" |
|||
if disableSSL { |
|||
scheme = "http" |
|||
} |
|||
endpoint = fmt.Sprintf("%s://%s", scheme, endpoint) |
|||
} |
|||
|
|||
return endpoint |
|||
} |
@ -0,0 +1,78 @@ |
|||
{ |
|||
"version": 2, |
|||
"endpoints": { |
|||
"*/*": { |
|||
"endpoint": "{service}.{region}.amazonaws.com" |
|||
}, |
|||
"cn-north-1/*": { |
|||
"endpoint": "{service}.{region}.amazonaws.com.cn", |
|||
"signatureVersion": "v4" |
|||
}, |
|||
"cn-north-1/ec2metadata": { |
|||
"endpoint": "http://169.254.169.254/latest" |
|||
}, |
|||
"us-gov-west-1/iam": { |
|||
"endpoint": "iam.us-gov.amazonaws.com" |
|||
}, |
|||
"us-gov-west-1/sts": { |
|||
"endpoint": "sts.us-gov-west-1.amazonaws.com" |
|||
}, |
|||
"us-gov-west-1/s3": { |
|||
"endpoint": "s3-{region}.amazonaws.com" |
|||
}, |
|||
"us-gov-west-1/ec2metadata": { |
|||
"endpoint": "http://169.254.169.254/latest" |
|||
}, |
|||
"*/cloudfront": { |
|||
"endpoint": "cloudfront.amazonaws.com", |
|||
"signingRegion": "us-east-1" |
|||
}, |
|||
"*/cloudsearchdomain": { |
|||
"endpoint": "", |
|||
"signingRegion": "us-east-1" |
|||
}, |
|||
"*/data.iot": { |
|||
"endpoint": "", |
|||
"signingRegion": "us-east-1" |
|||
}, |
|||
"*/ec2metadata": { |
|||
"endpoint": "http://169.254.169.254/latest" |
|||
}, |
|||
"*/iam": { |
|||
"endpoint": "iam.amazonaws.com", |
|||
"signingRegion": "us-east-1" |
|||
}, |
|||
"*/importexport": { |
|||
"endpoint": "importexport.amazonaws.com", |
|||
"signingRegion": "us-east-1" |
|||
}, |
|||
"*/route53": { |
|||
"endpoint": "route53.amazonaws.com", |
|||
"signingRegion": "us-east-1" |
|||
}, |
|||
"*/sts": { |
|||
"endpoint": "sts.amazonaws.com", |
|||
"signingRegion": "us-east-1" |
|||
}, |
|||
"*/waf": { |
|||
"endpoint": "waf.amazonaws.com", |
|||
"signingRegion": "us-east-1" |
|||
}, |
|||
"us-east-1/sdb": { |
|||
"endpoint": "sdb.amazonaws.com", |
|||
"signingRegion": "us-east-1" |
|||
}, |
|||
"*/s3": { |
|||
"endpoint": "s3-{region}.amazonaws.com" |
|||
}, |
|||
"*/s3/dualstack": { |
|||
"endpoint": "s3.dualstack.{region}.amazonaws.com" |
|||
}, |
|||
"us-east-1/s3": { |
|||
"endpoint": "s3.amazonaws.com" |
|||
}, |
|||
"eu-central-1/s3": { |
|||
"endpoint": "{service}.{region}.amazonaws.com" |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,91 @@ |
|||
package endpoints |
|||
|
|||
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
|
|||
|
|||
type endpointStruct struct { |
|||
Version int |
|||
Endpoints map[string]endpointEntry |
|||
} |
|||
|
|||
type endpointEntry struct { |
|||
Endpoint string |
|||
SigningRegion string |
|||
} |
|||
|
|||
var endpointsMap = endpointStruct{ |
|||
Version: 2, |
|||
Endpoints: map[string]endpointEntry{ |
|||
"*/*": { |
|||
Endpoint: "{service}.{region}.amazonaws.com", |
|||
}, |
|||
"*/cloudfront": { |
|||
Endpoint: "cloudfront.amazonaws.com", |
|||
SigningRegion: "us-east-1", |
|||
}, |
|||
"*/cloudsearchdomain": { |
|||
Endpoint: "", |
|||
SigningRegion: "us-east-1", |
|||
}, |
|||
"*/data.iot": { |
|||
Endpoint: "", |
|||
SigningRegion: "us-east-1", |
|||
}, |
|||
"*/ec2metadata": { |
|||
Endpoint: "http://169.254.169.254/latest", |
|||
}, |
|||
"*/iam": { |
|||
Endpoint: "iam.amazonaws.com", |
|||
SigningRegion: "us-east-1", |
|||
}, |
|||
"*/importexport": { |
|||
Endpoint: "importexport.amazonaws.com", |
|||
SigningRegion: "us-east-1", |
|||
}, |
|||
"*/route53": { |
|||
Endpoint: "route53.amazonaws.com", |
|||
SigningRegion: "us-east-1", |
|||
}, |
|||
"*/s3": { |
|||
Endpoint: "s3-{region}.amazonaws.com", |
|||
}, |
|||
"*/s3/dualstack": { |
|||
Endpoint: "s3.dualstack.{region}.amazonaws.com", |
|||
}, |
|||
"*/sts": { |
|||
Endpoint: "sts.amazonaws.com", |
|||
SigningRegion: "us-east-1", |
|||
}, |
|||
"*/waf": { |
|||
Endpoint: "waf.amazonaws.com", |
|||
SigningRegion: "us-east-1", |
|||
}, |
|||
"cn-north-1/*": { |
|||
Endpoint: "{service}.{region}.amazonaws.com.cn", |
|||
}, |
|||
"cn-north-1/ec2metadata": { |
|||
Endpoint: "http://169.254.169.254/latest", |
|||
}, |
|||
"eu-central-1/s3": { |
|||
Endpoint: "{service}.{region}.amazonaws.com", |
|||
}, |
|||
"us-east-1/s3": { |
|||
Endpoint: "s3.amazonaws.com", |
|||
}, |
|||
"us-east-1/sdb": { |
|||
Endpoint: "sdb.amazonaws.com", |
|||
SigningRegion: "us-east-1", |
|||
}, |
|||
"us-gov-west-1/ec2metadata": { |
|||
Endpoint: "http://169.254.169.254/latest", |
|||
}, |
|||
"us-gov-west-1/iam": { |
|||
Endpoint: "iam.us-gov.amazonaws.com", |
|||
}, |
|||
"us-gov-west-1/s3": { |
|||
Endpoint: "s3-{region}.amazonaws.com", |
|||
}, |
|||
"us-gov-west-1/sts": { |
|||
Endpoint: "sts.us-gov-west-1.amazonaws.com", |
|||
}, |
|||
}, |
|||
} |
@ -0,0 +1,75 @@ |
|||
package protocol |
|||
|
|||
import ( |
|||
"crypto/rand" |
|||
"fmt" |
|||
"reflect" |
|||
) |
|||
|
|||
// RandReader is the random reader the protocol package will use to read
|
|||
// random bytes from. This is exported for testing, and should not be used.
|
|||
var RandReader = rand.Reader |
|||
|
|||
const idempotencyTokenFillTag = `idempotencyToken` |
|||
|
|||
// CanSetIdempotencyToken returns true if the struct field should be
|
|||
// automatically populated with a Idempotency token.
|
|||
//
|
|||
// Only *string and string type fields that are tagged with idempotencyToken
|
|||
// which are not already set can be auto filled.
|
|||
func CanSetIdempotencyToken(v reflect.Value, f reflect.StructField) bool { |
|||
switch u := v.Interface().(type) { |
|||
// To auto fill an Idempotency token the field must be a string,
|
|||
// tagged for auto fill, and have a zero value.
|
|||
case *string: |
|||
return u == nil && len(f.Tag.Get(idempotencyTokenFillTag)) != 0 |
|||
case string: |
|||
return len(u) == 0 && len(f.Tag.Get(idempotencyTokenFillTag)) != 0 |
|||
} |
|||
|
|||
return false |
|||
} |
|||
|
|||
// GetIdempotencyToken returns a randomly generated idempotency token.
|
|||
func GetIdempotencyToken() string { |
|||
b := make([]byte, 16) |
|||
RandReader.Read(b) |
|||
|
|||
return UUIDVersion4(b) |
|||
} |
|||
|
|||
// SetIdempotencyToken will set the value provided with a Idempotency Token.
|
|||
// Given that the value can be set. Will panic if value is not setable.
|
|||
func SetIdempotencyToken(v reflect.Value) { |
|||
if v.Kind() == reflect.Ptr { |
|||
if v.IsNil() && v.CanSet() { |
|||
v.Set(reflect.New(v.Type().Elem())) |
|||
} |
|||
v = v.Elem() |
|||
} |
|||
v = reflect.Indirect(v) |
|||
|
|||
if !v.CanSet() { |
|||
panic(fmt.Sprintf("unable to set idempotnecy token %v", v)) |
|||
} |
|||
|
|||
b := make([]byte, 16) |
|||
_, err := rand.Read(b) |
|||
if err != nil { |
|||
// TODO handle error
|
|||
return |
|||
} |
|||
|
|||
v.Set(reflect.ValueOf(UUIDVersion4(b))) |
|||
} |
|||
|
|||
// UUIDVersion4 returns a Version 4 random UUID from the byte slice provided
|
|||
func UUIDVersion4(u []byte) string { |
|||
// https://en.wikipedia.org/wiki/Universally_unique_identifier#Version_4_.28random.29
|
|||
// 13th character is "4"
|
|||
u[6] = (u[6] | 0x40) & 0x4F |
|||
// 17th character is "8", "9", "a", or "b"
|
|||
u[8] = (u[8] | 0x80) & 0xBF |
|||
|
|||
return fmt.Sprintf(`%X-%X-%X-%X-%X`, u[0:4], u[4:6], u[6:8], u[8:10], u[10:]) |
|||
} |
@ -0,0 +1,36 @@ |
|||
// Package query provides serialization of AWS query requests, and responses.
|
|||
package query |
|||
|
|||
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/query.json build_test.go
|
|||
|
|||
import ( |
|||
"net/url" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
"github.com/aws/aws-sdk-go/private/protocol/query/queryutil" |
|||
) |
|||
|
|||
// BuildHandler is a named request handler for building query protocol requests
|
|||
var BuildHandler = request.NamedHandler{Name: "awssdk.query.Build", Fn: Build} |
|||
|
|||
// Build builds a request for an AWS Query service.
|
|||
func Build(r *request.Request) { |
|||
body := url.Values{ |
|||
"Action": {r.Operation.Name}, |
|||
"Version": {r.ClientInfo.APIVersion}, |
|||
} |
|||
if err := queryutil.Parse(body, r.Params, false); err != nil { |
|||
r.Error = awserr.New("SerializationError", "failed encoding Query request", err) |
|||
return |
|||
} |
|||
|
|||
if r.ExpireTime == 0 { |
|||
r.HTTPRequest.Method = "POST" |
|||
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") |
|||
r.SetBufferBody([]byte(body.Encode())) |
|||
} else { // This is a pre-signed request
|
|||
r.HTTPRequest.Method = "GET" |
|||
r.HTTPRequest.URL.RawQuery = body.Encode() |
|||
} |
|||
} |
@ -0,0 +1,230 @@ |
|||
package queryutil |
|||
|
|||
import ( |
|||
"encoding/base64" |
|||
"fmt" |
|||
"net/url" |
|||
"reflect" |
|||
"sort" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/private/protocol" |
|||
) |
|||
|
|||
// Parse parses an object i and fills a url.Values object. The isEC2 flag
|
|||
// indicates if this is the EC2 Query sub-protocol.
|
|||
func Parse(body url.Values, i interface{}, isEC2 bool) error { |
|||
q := queryParser{isEC2: isEC2} |
|||
return q.parseValue(body, reflect.ValueOf(i), "", "") |
|||
} |
|||
|
|||
func elemOf(value reflect.Value) reflect.Value { |
|||
for value.Kind() == reflect.Ptr { |
|||
value = value.Elem() |
|||
} |
|||
return value |
|||
} |
|||
|
|||
type queryParser struct { |
|||
isEC2 bool |
|||
} |
|||
|
|||
func (q *queryParser) parseValue(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error { |
|||
value = elemOf(value) |
|||
|
|||
// no need to handle zero values
|
|||
if !value.IsValid() { |
|||
return nil |
|||
} |
|||
|
|||
t := tag.Get("type") |
|||
if t == "" { |
|||
switch value.Kind() { |
|||
case reflect.Struct: |
|||
t = "structure" |
|||
case reflect.Slice: |
|||
t = "list" |
|||
case reflect.Map: |
|||
t = "map" |
|||
} |
|||
} |
|||
|
|||
switch t { |
|||
case "structure": |
|||
return q.parseStruct(v, value, prefix) |
|||
case "list": |
|||
return q.parseList(v, value, prefix, tag) |
|||
case "map": |
|||
return q.parseMap(v, value, prefix, tag) |
|||
default: |
|||
return q.parseScalar(v, value, prefix, tag) |
|||
} |
|||
} |
|||
|
|||
func (q *queryParser) parseStruct(v url.Values, value reflect.Value, prefix string) error { |
|||
if !value.IsValid() { |
|||
return nil |
|||
} |
|||
|
|||
t := value.Type() |
|||
for i := 0; i < value.NumField(); i++ { |
|||
elemValue := elemOf(value.Field(i)) |
|||
field := t.Field(i) |
|||
|
|||
if field.PkgPath != "" { |
|||
continue // ignore unexported fields
|
|||
} |
|||
|
|||
if protocol.CanSetIdempotencyToken(value.Field(i), field) { |
|||
token := protocol.GetIdempotencyToken() |
|||
elemValue = reflect.ValueOf(token) |
|||
} |
|||
|
|||
var name string |
|||
if q.isEC2 { |
|||
name = field.Tag.Get("queryName") |
|||
} |
|||
if name == "" { |
|||
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" { |
|||
name = field.Tag.Get("locationNameList") |
|||
} else if locName := field.Tag.Get("locationName"); locName != "" { |
|||
name = locName |
|||
} |
|||
if name != "" && q.isEC2 { |
|||
name = strings.ToUpper(name[0:1]) + name[1:] |
|||
} |
|||
} |
|||
if name == "" { |
|||
name = field.Name |
|||
} |
|||
|
|||
if prefix != "" { |
|||
name = prefix + "." + name |
|||
} |
|||
|
|||
if err := q.parseValue(v, elemValue, name, field.Tag); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (q *queryParser) parseList(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error { |
|||
// If it's empty, generate an empty value
|
|||
if !value.IsNil() && value.Len() == 0 { |
|||
v.Set(prefix, "") |
|||
return nil |
|||
} |
|||
|
|||
// check for unflattened list member
|
|||
if !q.isEC2 && tag.Get("flattened") == "" { |
|||
prefix += ".member" |
|||
} |
|||
|
|||
for i := 0; i < value.Len(); i++ { |
|||
slicePrefix := prefix |
|||
if slicePrefix == "" { |
|||
slicePrefix = strconv.Itoa(i + 1) |
|||
} else { |
|||
slicePrefix = slicePrefix + "." + strconv.Itoa(i+1) |
|||
} |
|||
if err := q.parseValue(v, value.Index(i), slicePrefix, ""); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (q *queryParser) parseMap(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error { |
|||
// If it's empty, generate an empty value
|
|||
if !value.IsNil() && value.Len() == 0 { |
|||
v.Set(prefix, "") |
|||
return nil |
|||
} |
|||
|
|||
// check for unflattened list member
|
|||
if !q.isEC2 && tag.Get("flattened") == "" { |
|||
prefix += ".entry" |
|||
} |
|||
|
|||
// sort keys for improved serialization consistency.
|
|||
// this is not strictly necessary for protocol support.
|
|||
mapKeyValues := value.MapKeys() |
|||
mapKeys := map[string]reflect.Value{} |
|||
mapKeyNames := make([]string, len(mapKeyValues)) |
|||
for i, mapKey := range mapKeyValues { |
|||
name := mapKey.String() |
|||
mapKeys[name] = mapKey |
|||
mapKeyNames[i] = name |
|||
} |
|||
sort.Strings(mapKeyNames) |
|||
|
|||
for i, mapKeyName := range mapKeyNames { |
|||
mapKey := mapKeys[mapKeyName] |
|||
mapValue := value.MapIndex(mapKey) |
|||
|
|||
kname := tag.Get("locationNameKey") |
|||
if kname == "" { |
|||
kname = "key" |
|||
} |
|||
vname := tag.Get("locationNameValue") |
|||
if vname == "" { |
|||
vname = "value" |
|||
} |
|||
|
|||
// serialize key
|
|||
var keyName string |
|||
if prefix == "" { |
|||
keyName = strconv.Itoa(i+1) + "." + kname |
|||
} else { |
|||
keyName = prefix + "." + strconv.Itoa(i+1) + "." + kname |
|||
} |
|||
|
|||
if err := q.parseValue(v, mapKey, keyName, ""); err != nil { |
|||
return err |
|||
} |
|||
|
|||
// serialize value
|
|||
var valueName string |
|||
if prefix == "" { |
|||
valueName = strconv.Itoa(i+1) + "." + vname |
|||
} else { |
|||
valueName = prefix + "." + strconv.Itoa(i+1) + "." + vname |
|||
} |
|||
|
|||
if err := q.parseValue(v, mapValue, valueName, ""); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, tag reflect.StructTag) error { |
|||
switch value := r.Interface().(type) { |
|||
case string: |
|||
v.Set(name, value) |
|||
case []byte: |
|||
if !r.IsNil() { |
|||
v.Set(name, base64.StdEncoding.EncodeToString(value)) |
|||
} |
|||
case bool: |
|||
v.Set(name, strconv.FormatBool(value)) |
|||
case int64: |
|||
v.Set(name, strconv.FormatInt(value, 10)) |
|||
case int: |
|||
v.Set(name, strconv.Itoa(value)) |
|||
case float64: |
|||
v.Set(name, strconv.FormatFloat(value, 'f', -1, 64)) |
|||
case float32: |
|||
v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32)) |
|||
case time.Time: |
|||
const ISO8601UTC = "2006-01-02T15:04:05Z" |
|||
v.Set(name, value.UTC().Format(ISO8601UTC)) |
|||
default: |
|||
return fmt.Errorf("unsupported value for param %s: %v (%s)", name, r.Interface(), r.Type().Name()) |
|||
} |
|||
return nil |
|||
} |
@ -0,0 +1,35 @@ |
|||
package query |
|||
|
|||
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/query.json unmarshal_test.go
|
|||
|
|||
import ( |
|||
"encoding/xml" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" |
|||
) |
|||
|
|||
// UnmarshalHandler is a named request handler for unmarshaling query protocol requests
|
|||
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.query.Unmarshal", Fn: Unmarshal} |
|||
|
|||
// UnmarshalMetaHandler is a named request handler for unmarshaling query protocol request metadata
|
|||
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.query.UnmarshalMeta", Fn: UnmarshalMeta} |
|||
|
|||
// Unmarshal unmarshals a response for an AWS Query service.
|
|||
func Unmarshal(r *request.Request) { |
|||
defer r.HTTPResponse.Body.Close() |
|||
if r.DataFilled() { |
|||
decoder := xml.NewDecoder(r.HTTPResponse.Body) |
|||
err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result") |
|||
if err != nil { |
|||
r.Error = awserr.New("SerializationError", "failed decoding Query response", err) |
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
// UnmarshalMeta unmarshals header response values for an AWS Query service.
|
|||
func UnmarshalMeta(r *request.Request) { |
|||
r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid") |
|||
} |
@ -0,0 +1,66 @@ |
|||
package query |
|||
|
|||
import ( |
|||
"encoding/xml" |
|||
"io/ioutil" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
type xmlErrorResponse struct { |
|||
XMLName xml.Name `xml:"ErrorResponse"` |
|||
Code string `xml:"Error>Code"` |
|||
Message string `xml:"Error>Message"` |
|||
RequestID string `xml:"RequestId"` |
|||
} |
|||
|
|||
type xmlServiceUnavailableResponse struct { |
|||
XMLName xml.Name `xml:"ServiceUnavailableException"` |
|||
} |
|||
|
|||
// UnmarshalErrorHandler is a name request handler to unmarshal request errors
|
|||
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.query.UnmarshalError", Fn: UnmarshalError} |
|||
|
|||
// UnmarshalError unmarshals an error response for an AWS Query service.
|
|||
func UnmarshalError(r *request.Request) { |
|||
defer r.HTTPResponse.Body.Close() |
|||
|
|||
bodyBytes, err := ioutil.ReadAll(r.HTTPResponse.Body) |
|||
if err != nil { |
|||
r.Error = awserr.New("SerializationError", "failed to read from query HTTP response body", err) |
|||
return |
|||
} |
|||
|
|||
// First check for specific error
|
|||
resp := xmlErrorResponse{} |
|||
decodeErr := xml.Unmarshal(bodyBytes, &resp) |
|||
if decodeErr == nil { |
|||
reqID := resp.RequestID |
|||
if reqID == "" { |
|||
reqID = r.RequestID |
|||
} |
|||
r.Error = awserr.NewRequestFailure( |
|||
awserr.New(resp.Code, resp.Message, nil), |
|||
r.HTTPResponse.StatusCode, |
|||
reqID, |
|||
) |
|||
return |
|||
} |
|||
|
|||
// Check for unhandled error
|
|||
servUnavailResp := xmlServiceUnavailableResponse{} |
|||
unavailErr := xml.Unmarshal(bodyBytes, &servUnavailResp) |
|||
if unavailErr == nil { |
|||
r.Error = awserr.NewRequestFailure( |
|||
awserr.New("ServiceUnavailableException", "service is unavailable", nil), |
|||
r.HTTPResponse.StatusCode, |
|||
r.RequestID, |
|||
) |
|||
return |
|||
} |
|||
|
|||
// Failed to retrieve any error message from the response body
|
|||
r.Error = awserr.New("SerializationError", |
|||
"failed to decode query XML error response", decodeErr) |
|||
} |
@ -0,0 +1,256 @@ |
|||
// Package rest provides RESTful serialization of AWS requests and responses.
|
|||
package rest |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/base64" |
|||
"fmt" |
|||
"io" |
|||
"net/http" |
|||
"net/url" |
|||
"path" |
|||
"reflect" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
// RFC822 returns an RFC822 formatted timestamp for AWS protocols
|
|||
const RFC822 = "Mon, 2 Jan 2006 15:04:05 GMT" |
|||
|
|||
// Whether the byte value can be sent without escaping in AWS URLs
|
|||
var noEscape [256]bool |
|||
|
|||
var errValueNotSet = fmt.Errorf("value not set") |
|||
|
|||
func init() { |
|||
for i := 0; i < len(noEscape); i++ { |
|||
// AWS expects every character except these to be escaped
|
|||
noEscape[i] = (i >= 'A' && i <= 'Z') || |
|||
(i >= 'a' && i <= 'z') || |
|||
(i >= '0' && i <= '9') || |
|||
i == '-' || |
|||
i == '.' || |
|||
i == '_' || |
|||
i == '~' |
|||
} |
|||
} |
|||
|
|||
// BuildHandler is a named request handler for building rest protocol requests
|
|||
var BuildHandler = request.NamedHandler{Name: "awssdk.rest.Build", Fn: Build} |
|||
|
|||
// Build builds the REST component of a service request.
|
|||
func Build(r *request.Request) { |
|||
if r.ParamsFilled() { |
|||
v := reflect.ValueOf(r.Params).Elem() |
|||
buildLocationElements(r, v) |
|||
buildBody(r, v) |
|||
} |
|||
} |
|||
|
|||
func buildLocationElements(r *request.Request, v reflect.Value) { |
|||
query := r.HTTPRequest.URL.Query() |
|||
|
|||
for i := 0; i < v.NumField(); i++ { |
|||
m := v.Field(i) |
|||
if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) { |
|||
continue |
|||
} |
|||
|
|||
if m.IsValid() { |
|||
field := v.Type().Field(i) |
|||
name := field.Tag.Get("locationName") |
|||
if name == "" { |
|||
name = field.Name |
|||
} |
|||
if m.Kind() == reflect.Ptr { |
|||
m = m.Elem() |
|||
} |
|||
if !m.IsValid() { |
|||
continue |
|||
} |
|||
|
|||
var err error |
|||
switch field.Tag.Get("location") { |
|||
case "headers": // header maps
|
|||
err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag.Get("locationName")) |
|||
case "header": |
|||
err = buildHeader(&r.HTTPRequest.Header, m, name) |
|||
case "uri": |
|||
err = buildURI(r.HTTPRequest.URL, m, name) |
|||
case "querystring": |
|||
err = buildQueryString(query, m, name) |
|||
} |
|||
r.Error = err |
|||
} |
|||
if r.Error != nil { |
|||
return |
|||
} |
|||
} |
|||
|
|||
r.HTTPRequest.URL.RawQuery = query.Encode() |
|||
updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path) |
|||
} |
|||
|
|||
func buildBody(r *request.Request, v reflect.Value) { |
|||
if field, ok := v.Type().FieldByName("_"); ok { |
|||
if payloadName := field.Tag.Get("payload"); payloadName != "" { |
|||
pfield, _ := v.Type().FieldByName(payloadName) |
|||
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" { |
|||
payload := reflect.Indirect(v.FieldByName(payloadName)) |
|||
if payload.IsValid() && payload.Interface() != nil { |
|||
switch reader := payload.Interface().(type) { |
|||
case io.ReadSeeker: |
|||
r.SetReaderBody(reader) |
|||
case []byte: |
|||
r.SetBufferBody(reader) |
|||
case string: |
|||
r.SetStringBody(reader) |
|||
default: |
|||
r.Error = awserr.New("SerializationError", |
|||
"failed to encode REST request", |
|||
fmt.Errorf("unknown payload type %s", payload.Type())) |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func buildHeader(header *http.Header, v reflect.Value, name string) error { |
|||
str, err := convertType(v) |
|||
if err == errValueNotSet { |
|||
return nil |
|||
} else if err != nil { |
|||
return awserr.New("SerializationError", "failed to encode REST request", err) |
|||
} |
|||
|
|||
header.Add(name, str) |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func buildHeaderMap(header *http.Header, v reflect.Value, prefix string) error { |
|||
for _, key := range v.MapKeys() { |
|||
str, err := convertType(v.MapIndex(key)) |
|||
if err == errValueNotSet { |
|||
continue |
|||
} else if err != nil { |
|||
return awserr.New("SerializationError", "failed to encode REST request", err) |
|||
|
|||
} |
|||
|
|||
header.Add(prefix+key.String(), str) |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func buildURI(u *url.URL, v reflect.Value, name string) error { |
|||
value, err := convertType(v) |
|||
if err == errValueNotSet { |
|||
return nil |
|||
} else if err != nil { |
|||
return awserr.New("SerializationError", "failed to encode REST request", err) |
|||
} |
|||
|
|||
uri := u.Path |
|||
uri = strings.Replace(uri, "{"+name+"}", EscapePath(value, true), -1) |
|||
uri = strings.Replace(uri, "{"+name+"+}", EscapePath(value, false), -1) |
|||
u.Path = uri |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func buildQueryString(query url.Values, v reflect.Value, name string) error { |
|||
switch value := v.Interface().(type) { |
|||
case []*string: |
|||
for _, item := range value { |
|||
query.Add(name, *item) |
|||
} |
|||
case map[string]*string: |
|||
for key, item := range value { |
|||
query.Add(key, *item) |
|||
} |
|||
case map[string][]*string: |
|||
for key, items := range value { |
|||
for _, item := range items { |
|||
query.Add(key, *item) |
|||
} |
|||
} |
|||
default: |
|||
str, err := convertType(v) |
|||
if err == errValueNotSet { |
|||
return nil |
|||
} else if err != nil { |
|||
return awserr.New("SerializationError", "failed to encode REST request", err) |
|||
} |
|||
query.Set(name, str) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func updatePath(url *url.URL, urlPath string) { |
|||
scheme, query := url.Scheme, url.RawQuery |
|||
|
|||
hasSlash := strings.HasSuffix(urlPath, "/") |
|||
|
|||
// clean up path
|
|||
urlPath = path.Clean(urlPath) |
|||
if hasSlash && !strings.HasSuffix(urlPath, "/") { |
|||
urlPath += "/" |
|||
} |
|||
|
|||
// get formatted URL minus scheme so we can build this into Opaque
|
|||
url.Scheme, url.Path, url.RawQuery = "", "", "" |
|||
s := url.String() |
|||
url.Scheme = scheme |
|||
url.RawQuery = query |
|||
|
|||
// build opaque URI
|
|||
url.Opaque = s + urlPath |
|||
} |
|||
|
|||
// EscapePath escapes part of a URL path in Amazon style
|
|||
func EscapePath(path string, encodeSep bool) string { |
|||
var buf bytes.Buffer |
|||
for i := 0; i < len(path); i++ { |
|||
c := path[i] |
|||
if noEscape[c] || (c == '/' && !encodeSep) { |
|||
buf.WriteByte(c) |
|||
} else { |
|||
fmt.Fprintf(&buf, "%%%02X", c) |
|||
} |
|||
} |
|||
return buf.String() |
|||
} |
|||
|
|||
func convertType(v reflect.Value) (string, error) { |
|||
v = reflect.Indirect(v) |
|||
if !v.IsValid() { |
|||
return "", errValueNotSet |
|||
} |
|||
|
|||
var str string |
|||
switch value := v.Interface().(type) { |
|||
case string: |
|||
str = value |
|||
case []byte: |
|||
str = base64.StdEncoding.EncodeToString(value) |
|||
case bool: |
|||
str = strconv.FormatBool(value) |
|||
case int64: |
|||
str = strconv.FormatInt(value, 10) |
|||
case float64: |
|||
str = strconv.FormatFloat(value, 'f', -1, 64) |
|||
case time.Time: |
|||
str = value.UTC().Format(RFC822) |
|||
default: |
|||
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type()) |
|||
return "", err |
|||
} |
|||
return str, nil |
|||
} |
@ -0,0 +1,45 @@ |
|||
package rest |
|||
|
|||
import "reflect" |
|||
|
|||
// PayloadMember returns the payload field member of i if there is one, or nil.
|
|||
func PayloadMember(i interface{}) interface{} { |
|||
if i == nil { |
|||
return nil |
|||
} |
|||
|
|||
v := reflect.ValueOf(i).Elem() |
|||
if !v.IsValid() { |
|||
return nil |
|||
} |
|||
if field, ok := v.Type().FieldByName("_"); ok { |
|||
if payloadName := field.Tag.Get("payload"); payloadName != "" { |
|||
field, _ := v.Type().FieldByName(payloadName) |
|||
if field.Tag.Get("type") != "structure" { |
|||
return nil |
|||
} |
|||
|
|||
payload := v.FieldByName(payloadName) |
|||
if payload.IsValid() || (payload.Kind() == reflect.Ptr && !payload.IsNil()) { |
|||
return payload.Interface() |
|||
} |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// PayloadType returns the type of a payload field member of i if there is one, or "".
|
|||
func PayloadType(i interface{}) string { |
|||
v := reflect.Indirect(reflect.ValueOf(i)) |
|||
if !v.IsValid() { |
|||
return "" |
|||
} |
|||
if field, ok := v.Type().FieldByName("_"); ok { |
|||
if payloadName := field.Tag.Get("payload"); payloadName != "" { |
|||
if member, ok := v.Type().FieldByName(payloadName); ok { |
|||
return member.Tag.Get("type") |
|||
} |
|||
} |
|||
} |
|||
return "" |
|||
} |
@ -0,0 +1,198 @@ |
|||
package rest |
|||
|
|||
import ( |
|||
"encoding/base64" |
|||
"fmt" |
|||
"io" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"reflect" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/awserr" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
// UnmarshalHandler is a named request handler for unmarshaling rest protocol requests
|
|||
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal} |
|||
|
|||
// UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata
|
|||
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta} |
|||
|
|||
// Unmarshal unmarshals the REST component of a response in a REST service.
|
|||
func Unmarshal(r *request.Request) { |
|||
if r.DataFilled() { |
|||
v := reflect.Indirect(reflect.ValueOf(r.Data)) |
|||
unmarshalBody(r, v) |
|||
} |
|||
} |
|||
|
|||
// UnmarshalMeta unmarshals the REST metadata of a response in a REST service
|
|||
func UnmarshalMeta(r *request.Request) { |
|||
r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid") |
|||
if r.RequestID == "" { |
|||
// Alternative version of request id in the header
|
|||
r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id") |
|||
} |
|||
if r.DataFilled() { |
|||
v := reflect.Indirect(reflect.ValueOf(r.Data)) |
|||
unmarshalLocationElements(r, v) |
|||
} |
|||
} |
|||
|
|||
func unmarshalBody(r *request.Request, v reflect.Value) { |
|||
if field, ok := v.Type().FieldByName("_"); ok { |
|||
if payloadName := field.Tag.Get("payload"); payloadName != "" { |
|||
pfield, _ := v.Type().FieldByName(payloadName) |
|||
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" { |
|||
payload := v.FieldByName(payloadName) |
|||
if payload.IsValid() { |
|||
switch payload.Interface().(type) { |
|||
case []byte: |
|||
defer r.HTTPResponse.Body.Close() |
|||
b, err := ioutil.ReadAll(r.HTTPResponse.Body) |
|||
if err != nil { |
|||
r.Error = awserr.New("SerializationError", "failed to decode REST response", err) |
|||
} else { |
|||
payload.Set(reflect.ValueOf(b)) |
|||
} |
|||
case *string: |
|||
defer r.HTTPResponse.Body.Close() |
|||
b, err := ioutil.ReadAll(r.HTTPResponse.Body) |
|||
if err != nil { |
|||
r.Error = awserr.New("SerializationError", "failed to decode REST response", err) |
|||
} else { |
|||
str := string(b) |
|||
payload.Set(reflect.ValueOf(&str)) |
|||
} |
|||
default: |
|||
switch payload.Type().String() { |
|||
case "io.ReadSeeker": |
|||
payload.Set(reflect.ValueOf(aws.ReadSeekCloser(r.HTTPResponse.Body))) |
|||
case "aws.ReadSeekCloser", "io.ReadCloser": |
|||
payload.Set(reflect.ValueOf(r.HTTPResponse.Body)) |
|||
default: |
|||
io.Copy(ioutil.Discard, r.HTTPResponse.Body) |
|||
defer r.HTTPResponse.Body.Close() |
|||
r.Error = awserr.New("SerializationError", |
|||
"failed to decode REST response", |
|||
fmt.Errorf("unknown payload type %s", payload.Type())) |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func unmarshalLocationElements(r *request.Request, v reflect.Value) { |
|||
for i := 0; i < v.NumField(); i++ { |
|||
m, field := v.Field(i), v.Type().Field(i) |
|||
if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) { |
|||
continue |
|||
} |
|||
|
|||
if m.IsValid() { |
|||
name := field.Tag.Get("locationName") |
|||
if name == "" { |
|||
name = field.Name |
|||
} |
|||
|
|||
switch field.Tag.Get("location") { |
|||
case "statusCode": |
|||
unmarshalStatusCode(m, r.HTTPResponse.StatusCode) |
|||
case "header": |
|||
err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name)) |
|||
if err != nil { |
|||
r.Error = awserr.New("SerializationError", "failed to decode REST response", err) |
|||
break |
|||
} |
|||
case "headers": |
|||
prefix := field.Tag.Get("locationName") |
|||
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix) |
|||
if err != nil { |
|||
r.Error = awserr.New("SerializationError", "failed to decode REST response", err) |
|||
break |
|||
} |
|||
} |
|||
} |
|||
if r.Error != nil { |
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
func unmarshalStatusCode(v reflect.Value, statusCode int) { |
|||
if !v.IsValid() { |
|||
return |
|||
} |
|||
|
|||
switch v.Interface().(type) { |
|||
case *int64: |
|||
s := int64(statusCode) |
|||
v.Set(reflect.ValueOf(&s)) |
|||
} |
|||
} |
|||
|
|||
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error { |
|||
switch r.Interface().(type) { |
|||
case map[string]*string: // we only support string map value types
|
|||
out := map[string]*string{} |
|||
for k, v := range headers { |
|||
k = http.CanonicalHeaderKey(k) |
|||
if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) { |
|||
out[k[len(prefix):]] = &v[0] |
|||
} |
|||
} |
|||
r.Set(reflect.ValueOf(out)) |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func unmarshalHeader(v reflect.Value, header string) error { |
|||
if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) { |
|||
return nil |
|||
} |
|||
|
|||
switch v.Interface().(type) { |
|||
case *string: |
|||
v.Set(reflect.ValueOf(&header)) |
|||
case []byte: |
|||
b, err := base64.StdEncoding.DecodeString(header) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
v.Set(reflect.ValueOf(&b)) |
|||
case *bool: |
|||
b, err := strconv.ParseBool(header) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
v.Set(reflect.ValueOf(&b)) |
|||
case *int64: |
|||
i, err := strconv.ParseInt(header, 10, 64) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
v.Set(reflect.ValueOf(&i)) |
|||
case *float64: |
|||
f, err := strconv.ParseFloat(header, 64) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
v.Set(reflect.ValueOf(&f)) |
|||
case *time.Time: |
|||
t, err := time.Parse(RFC822, header) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
v.Set(reflect.ValueOf(&t)) |
|||
default: |
|||
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type()) |
|||
return err |
|||
} |
|||
return nil |
|||
} |
@ -0,0 +1,21 @@ |
|||
package protocol |
|||
|
|||
import ( |
|||
"io" |
|||
"io/ioutil" |
|||
|
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
) |
|||
|
|||
// UnmarshalDiscardBodyHandler is a named request handler to empty and close a response's body
|
|||
var UnmarshalDiscardBodyHandler = request.NamedHandler{Name: "awssdk.shared.UnmarshalDiscardBody", Fn: UnmarshalDiscardBody} |
|||
|
|||
// UnmarshalDiscardBody is a request handler to empty a response's body and closing it.
|
|||
func UnmarshalDiscardBody(r *request.Request) { |
|||
if r.HTTPResponse == nil || r.HTTPResponse.Body == nil { |
|||
return |
|||
} |
|||
|
|||
io.Copy(ioutil.Discard, r.HTTPResponse.Body) |
|||
r.HTTPResponse.Body.Close() |
|||
} |
@ -0,0 +1,293 @@ |
|||
// Package xmlutil provides XML serialization of AWS requests and responses.
|
|||
package xmlutil |
|||
|
|||
import ( |
|||
"encoding/base64" |
|||
"encoding/xml" |
|||
"fmt" |
|||
"reflect" |
|||
"sort" |
|||
"strconv" |
|||
"time" |
|||
|
|||
"github.com/aws/aws-sdk-go/private/protocol" |
|||
) |
|||
|
|||
// BuildXML will serialize params into an xml.Encoder.
|
|||
// Error will be returned if the serialization of any of the params or nested values fails.
|
|||
func BuildXML(params interface{}, e *xml.Encoder) error { |
|||
b := xmlBuilder{encoder: e, namespaces: map[string]string{}} |
|||
root := NewXMLElement(xml.Name{}) |
|||
if err := b.buildValue(reflect.ValueOf(params), root, ""); err != nil { |
|||
return err |
|||
} |
|||
for _, c := range root.Children { |
|||
for _, v := range c { |
|||
return StructToXML(e, v, false) |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// Returns the reflection element of a value, if it is a pointer.
|
|||
func elemOf(value reflect.Value) reflect.Value { |
|||
for value.Kind() == reflect.Ptr { |
|||
value = value.Elem() |
|||
} |
|||
return value |
|||
} |
|||
|
|||
// A xmlBuilder serializes values from Go code to XML
|
|||
type xmlBuilder struct { |
|||
encoder *xml.Encoder |
|||
namespaces map[string]string |
|||
} |
|||
|
|||
// buildValue generic XMLNode builder for any type. Will build value for their specific type
|
|||
// struct, list, map, scalar.
|
|||
//
|
|||
// Also takes a "type" tag value to set what type a value should be converted to XMLNode as. If
|
|||
// type is not provided reflect will be used to determine the value's type.
|
|||
func (b *xmlBuilder) buildValue(value reflect.Value, current *XMLNode, tag reflect.StructTag) error { |
|||
value = elemOf(value) |
|||
if !value.IsValid() { // no need to handle zero values
|
|||
return nil |
|||
} else if tag.Get("location") != "" { // don't handle non-body location values
|
|||
return nil |
|||
} |
|||
|
|||
t := tag.Get("type") |
|||
if t == "" { |
|||
switch value.Kind() { |
|||
case reflect.Struct: |
|||
t = "structure" |
|||
case reflect.Slice: |
|||
t = "list" |
|||
case reflect.Map: |
|||
t = "map" |
|||
} |
|||
} |
|||
|
|||
switch t { |
|||
case "structure": |
|||
if field, ok := value.Type().FieldByName("_"); ok { |
|||
tag = tag + reflect.StructTag(" ") + field.Tag |
|||
} |
|||
return b.buildStruct(value, current, tag) |
|||
case "list": |
|||
return b.buildList(value, current, tag) |
|||
case "map": |
|||
return b.buildMap(value, current, tag) |
|||
default: |
|||
return b.buildScalar(value, current, tag) |
|||
} |
|||
} |
|||
|
|||
// buildStruct adds a struct and its fields to the current XMLNode. All fields any any nested
|
|||
// types are converted to XMLNodes also.
|
|||
func (b *xmlBuilder) buildStruct(value reflect.Value, current *XMLNode, tag reflect.StructTag) error { |
|||
if !value.IsValid() { |
|||
return nil |
|||
} |
|||
|
|||
fieldAdded := false |
|||
|
|||
// unwrap payloads
|
|||
if payload := tag.Get("payload"); payload != "" { |
|||
field, _ := value.Type().FieldByName(payload) |
|||
tag = field.Tag |
|||
value = elemOf(value.FieldByName(payload)) |
|||
|
|||
if !value.IsValid() { |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
child := NewXMLElement(xml.Name{Local: tag.Get("locationName")}) |
|||
|
|||
// there is an xmlNamespace associated with this struct
|
|||
if prefix, uri := tag.Get("xmlPrefix"), tag.Get("xmlURI"); uri != "" { |
|||
ns := xml.Attr{ |
|||
Name: xml.Name{Local: "xmlns"}, |
|||
Value: uri, |
|||
} |
|||
if prefix != "" { |
|||
b.namespaces[prefix] = uri // register the namespace
|
|||
ns.Name.Local = "xmlns:" + prefix |
|||
} |
|||
|
|||
child.Attr = append(child.Attr, ns) |
|||
} |
|||
|
|||
t := value.Type() |
|||
for i := 0; i < value.NumField(); i++ { |
|||
member := elemOf(value.Field(i)) |
|||
field := t.Field(i) |
|||
|
|||
if field.PkgPath != "" { |
|||
continue // ignore unexported fields
|
|||
} |
|||
|
|||
mTag := field.Tag |
|||
if mTag.Get("location") != "" { // skip non-body members
|
|||
continue |
|||
} |
|||
|
|||
if protocol.CanSetIdempotencyToken(value.Field(i), field) { |
|||
token := protocol.GetIdempotencyToken() |
|||
member = reflect.ValueOf(token) |
|||
} |
|||
|
|||
memberName := mTag.Get("locationName") |
|||
if memberName == "" { |
|||
memberName = field.Name |
|||
mTag = reflect.StructTag(string(mTag) + ` locationName:"` + memberName + `"`) |
|||
} |
|||
if err := b.buildValue(member, child, mTag); err != nil { |
|||
return err |
|||
} |
|||
|
|||
fieldAdded = true |
|||
} |
|||
|
|||
if fieldAdded { // only append this child if we have one ore more valid members
|
|||
current.AddChild(child) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// buildList adds the value's list items to the current XMLNode as children nodes. All
|
|||
// nested values in the list are converted to XMLNodes also.
|
|||
func (b *xmlBuilder) buildList(value reflect.Value, current *XMLNode, tag reflect.StructTag) error { |
|||
if value.IsNil() { // don't build omitted lists
|
|||
return nil |
|||
} |
|||
|
|||
// check for unflattened list member
|
|||
flattened := tag.Get("flattened") != "" |
|||
|
|||
xname := xml.Name{Local: tag.Get("locationName")} |
|||
if flattened { |
|||
for i := 0; i < value.Len(); i++ { |
|||
child := NewXMLElement(xname) |
|||
current.AddChild(child) |
|||
if err := b.buildValue(value.Index(i), child, ""); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
} else { |
|||
list := NewXMLElement(xname) |
|||
current.AddChild(list) |
|||
|
|||
for i := 0; i < value.Len(); i++ { |
|||
iname := tag.Get("locationNameList") |
|||
if iname == "" { |
|||
iname = "member" |
|||
} |
|||
|
|||
child := NewXMLElement(xml.Name{Local: iname}) |
|||
list.AddChild(child) |
|||
if err := b.buildValue(value.Index(i), child, ""); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// buildMap adds the value's key/value pairs to the current XMLNode as children nodes. All
|
|||
// nested values in the map are converted to XMLNodes also.
|
|||
//
|
|||
// Error will be returned if it is unable to build the map's values into XMLNodes
|
|||
func (b *xmlBuilder) buildMap(value reflect.Value, current *XMLNode, tag reflect.StructTag) error { |
|||
if value.IsNil() { // don't build omitted maps
|
|||
return nil |
|||
} |
|||
|
|||
maproot := NewXMLElement(xml.Name{Local: tag.Get("locationName")}) |
|||
current.AddChild(maproot) |
|||
current = maproot |
|||
|
|||
kname, vname := "key", "value" |
|||
if n := tag.Get("locationNameKey"); n != "" { |
|||
kname = n |
|||
} |
|||
if n := tag.Get("locationNameValue"); n != "" { |
|||
vname = n |
|||
} |
|||
|
|||
// sorting is not required for compliance, but it makes testing easier
|
|||
keys := make([]string, value.Len()) |
|||
for i, k := range value.MapKeys() { |
|||
keys[i] = k.String() |
|||
} |
|||
sort.Strings(keys) |
|||
|
|||
for _, k := range keys { |
|||
v := value.MapIndex(reflect.ValueOf(k)) |
|||
|
|||
mapcur := current |
|||
if tag.Get("flattened") == "" { // add "entry" tag to non-flat maps
|
|||
child := NewXMLElement(xml.Name{Local: "entry"}) |
|||
mapcur.AddChild(child) |
|||
mapcur = child |
|||
} |
|||
|
|||
kchild := NewXMLElement(xml.Name{Local: kname}) |
|||
kchild.Text = k |
|||
vchild := NewXMLElement(xml.Name{Local: vname}) |
|||
mapcur.AddChild(kchild) |
|||
mapcur.AddChild(vchild) |
|||
|
|||
if err := b.buildValue(v, vchild, ""); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// buildScalar will convert the value into a string and append it as a attribute or child
|
|||
// of the current XMLNode.
|
|||
//
|
|||
// The value will be added as an attribute if tag contains a "xmlAttribute" attribute value.
|
|||
//
|
|||
// Error will be returned if the value type is unsupported.
|
|||
func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag reflect.StructTag) error { |
|||
var str string |
|||
switch converted := value.Interface().(type) { |
|||
case string: |
|||
str = converted |
|||
case []byte: |
|||
if !value.IsNil() { |
|||
str = base64.StdEncoding.EncodeToString(converted) |
|||
} |
|||
case bool: |
|||
str = strconv.FormatBool(converted) |
|||
case int64: |
|||
str = strconv.FormatInt(converted, 10) |
|||
case int: |
|||
str = strconv.Itoa(converted) |
|||
case float64: |
|||
str = strconv.FormatFloat(converted, 'f', -1, 64) |
|||
case float32: |
|||
str = strconv.FormatFloat(float64(converted), 'f', -1, 32) |
|||
case time.Time: |
|||
const ISO8601UTC = "2006-01-02T15:04:05Z" |
|||
str = converted.UTC().Format(ISO8601UTC) |
|||
default: |
|||
return fmt.Errorf("unsupported value for param %s: %v (%s)", |
|||
tag.Get("locationName"), value.Interface(), value.Type().Name()) |
|||
} |
|||
|
|||
xname := xml.Name{Local: tag.Get("locationName")} |
|||
if tag.Get("xmlAttribute") != "" { // put into current node's attribute list
|
|||
attr := xml.Attr{Name: xname, Value: str} |
|||
current.Attr = append(current.Attr, attr) |
|||
} else { // regular text node
|
|||
current.AddChild(&XMLNode{Name: xname, Text: str}) |
|||
} |
|||
return nil |
|||
} |
@ -0,0 +1,260 @@ |
|||
package xmlutil |
|||
|
|||
import ( |
|||
"encoding/base64" |
|||
"encoding/xml" |
|||
"fmt" |
|||
"io" |
|||
"reflect" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
) |
|||
|
|||
// UnmarshalXML deserializes an xml.Decoder into the container v. V
|
|||
// needs to match the shape of the XML expected to be decoded.
|
|||
// If the shape doesn't match unmarshaling will fail.
|
|||
func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error { |
|||
n, _ := XMLToStruct(d, nil) |
|||
if n.Children != nil { |
|||
for _, root := range n.Children { |
|||
for _, c := range root { |
|||
if wrappedChild, ok := c.Children[wrapper]; ok { |
|||
c = wrappedChild[0] // pull out wrapped element
|
|||
} |
|||
|
|||
err := parse(reflect.ValueOf(v), c, "") |
|||
if err != nil { |
|||
if err == io.EOF { |
|||
return nil |
|||
} |
|||
return err |
|||
} |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
|
|||
// will be used to determine the type from r.
|
|||
func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { |
|||
rtype := r.Type() |
|||
if rtype.Kind() == reflect.Ptr { |
|||
rtype = rtype.Elem() // check kind of actual element type
|
|||
} |
|||
|
|||
t := tag.Get("type") |
|||
if t == "" { |
|||
switch rtype.Kind() { |
|||
case reflect.Struct: |
|||
t = "structure" |
|||
case reflect.Slice: |
|||
t = "list" |
|||
case reflect.Map: |
|||
t = "map" |
|||
} |
|||
} |
|||
|
|||
switch t { |
|||
case "structure": |
|||
if field, ok := rtype.FieldByName("_"); ok { |
|||
tag = field.Tag |
|||
} |
|||
return parseStruct(r, node, tag) |
|||
case "list": |
|||
return parseList(r, node, tag) |
|||
case "map": |
|||
return parseMap(r, node, tag) |
|||
default: |
|||
return parseScalar(r, node, tag) |
|||
} |
|||
} |
|||
|
|||
// parseStruct deserializes a structure and its fields from an XMLNode. Any nested
|
|||
// types in the structure will also be deserialized.
|
|||
func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { |
|||
t := r.Type() |
|||
if r.Kind() == reflect.Ptr { |
|||
if r.IsNil() { // create the structure if it's nil
|
|||
s := reflect.New(r.Type().Elem()) |
|||
r.Set(s) |
|||
r = s |
|||
} |
|||
|
|||
r = r.Elem() |
|||
t = t.Elem() |
|||
} |
|||
|
|||
// unwrap any payloads
|
|||
if payload := tag.Get("payload"); payload != "" { |
|||
field, _ := t.FieldByName(payload) |
|||
return parseStruct(r.FieldByName(payload), node, field.Tag) |
|||
} |
|||
|
|||
for i := 0; i < t.NumField(); i++ { |
|||
field := t.Field(i) |
|||
if c := field.Name[0:1]; strings.ToLower(c) == c { |
|||
continue // ignore unexported fields
|
|||
} |
|||
|
|||
// figure out what this field is called
|
|||
name := field.Name |
|||
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" { |
|||
name = field.Tag.Get("locationNameList") |
|||
} else if locName := field.Tag.Get("locationName"); locName != "" { |
|||
name = locName |
|||
} |
|||
|
|||
// try to find the field by name in elements
|
|||
elems := node.Children[name] |
|||
|
|||
if elems == nil { // try to find the field in attributes
|
|||
for _, a := range node.Attr { |
|||
if name == a.Name.Local { |
|||
// turn this into a text node for de-serializing
|
|||
elems = []*XMLNode{{Text: a.Value}} |
|||
} |
|||
} |
|||
} |
|||
|
|||
member := r.FieldByName(field.Name) |
|||
for _, elem := range elems { |
|||
err := parse(member, elem, field.Tag) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// parseList deserializes a list of values from an XML node. Each list entry
|
|||
// will also be deserialized.
|
|||
func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { |
|||
t := r.Type() |
|||
|
|||
if tag.Get("flattened") == "" { // look at all item entries
|
|||
mname := "member" |
|||
if name := tag.Get("locationNameList"); name != "" { |
|||
mname = name |
|||
} |
|||
|
|||
if Children, ok := node.Children[mname]; ok { |
|||
if r.IsNil() { |
|||
r.Set(reflect.MakeSlice(t, len(Children), len(Children))) |
|||
} |
|||
|
|||
for i, c := range Children { |
|||
err := parse(r.Index(i), c, "") |
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
} |
|||
} else { // flattened list means this is a single element
|
|||
if r.IsNil() { |
|||
r.Set(reflect.MakeSlice(t, 0, 0)) |
|||
} |
|||
|
|||
childR := reflect.Zero(t.Elem()) |
|||
r.Set(reflect.Append(r, childR)) |
|||
err := parse(r.Index(r.Len()-1), node, "") |
|||
if err != nil { |
|||
return err |
|||
} |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// parseMap deserializes a map from an XMLNode. The direct children of the XMLNode
|
|||
// will also be deserialized as map entries.
|
|||
func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { |
|||
if r.IsNil() { |
|||
r.Set(reflect.MakeMap(r.Type())) |
|||
} |
|||
|
|||
if tag.Get("flattened") == "" { // look at all child entries
|
|||
for _, entry := range node.Children["entry"] { |
|||
parseMapEntry(r, entry, tag) |
|||
} |
|||
} else { // this element is itself an entry
|
|||
parseMapEntry(r, node, tag) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// parseMapEntry deserializes a map entry from a XML node.
|
|||
func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { |
|||
kname, vname := "key", "value" |
|||
if n := tag.Get("locationNameKey"); n != "" { |
|||
kname = n |
|||
} |
|||
if n := tag.Get("locationNameValue"); n != "" { |
|||
vname = n |
|||
} |
|||
|
|||
keys, ok := node.Children[kname] |
|||
values := node.Children[vname] |
|||
if ok { |
|||
for i, key := range keys { |
|||
keyR := reflect.ValueOf(key.Text) |
|||
value := values[i] |
|||
valueR := reflect.New(r.Type().Elem()).Elem() |
|||
|
|||
parse(valueR, value, "") |
|||
r.SetMapIndex(keyR, valueR) |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// parseScaller deserializes an XMLNode value into a concrete type based on the
|
|||
// interface type of r.
|
|||
//
|
|||
// Error is returned if the deserialization fails due to invalid type conversion,
|
|||
// or unsupported interface type.
|
|||
func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { |
|||
switch r.Interface().(type) { |
|||
case *string: |
|||
r.Set(reflect.ValueOf(&node.Text)) |
|||
return nil |
|||
case []byte: |
|||
b, err := base64.StdEncoding.DecodeString(node.Text) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
r.Set(reflect.ValueOf(b)) |
|||
case *bool: |
|||
v, err := strconv.ParseBool(node.Text) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
r.Set(reflect.ValueOf(&v)) |
|||
case *int64: |
|||
v, err := strconv.ParseInt(node.Text, 10, 64) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
r.Set(reflect.ValueOf(&v)) |
|||
case *float64: |
|||
v, err := strconv.ParseFloat(node.Text, 64) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
r.Set(reflect.ValueOf(&v)) |
|||
case *time.Time: |
|||
const ISO8601UTC = "2006-01-02T15:04:05Z" |
|||
t, err := time.Parse(ISO8601UTC, node.Text) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
r.Set(reflect.ValueOf(&t)) |
|||
default: |
|||
return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type()) |
|||
} |
|||
return nil |
|||
} |
@ -0,0 +1,105 @@ |
|||
package xmlutil |
|||
|
|||
import ( |
|||
"encoding/xml" |
|||
"io" |
|||
"sort" |
|||
) |
|||
|
|||
// A XMLNode contains the values to be encoded or decoded.
|
|||
type XMLNode struct { |
|||
Name xml.Name `json:",omitempty"` |
|||
Children map[string][]*XMLNode `json:",omitempty"` |
|||
Text string `json:",omitempty"` |
|||
Attr []xml.Attr `json:",omitempty"` |
|||
} |
|||
|
|||
// NewXMLElement returns a pointer to a new XMLNode initialized to default values.
|
|||
func NewXMLElement(name xml.Name) *XMLNode { |
|||
return &XMLNode{ |
|||
Name: name, |
|||
Children: map[string][]*XMLNode{}, |
|||
Attr: []xml.Attr{}, |
|||
} |
|||
} |
|||
|
|||
// AddChild adds child to the XMLNode.
|
|||
func (n *XMLNode) AddChild(child *XMLNode) { |
|||
if _, ok := n.Children[child.Name.Local]; !ok { |
|||
n.Children[child.Name.Local] = []*XMLNode{} |
|||
} |
|||
n.Children[child.Name.Local] = append(n.Children[child.Name.Local], child) |
|||
} |
|||
|
|||
// XMLToStruct converts a xml.Decoder stream to XMLNode with nested values.
|
|||
func XMLToStruct(d *xml.Decoder, s *xml.StartElement) (*XMLNode, error) { |
|||
out := &XMLNode{} |
|||
for { |
|||
tok, err := d.Token() |
|||
if tok == nil || err == io.EOF { |
|||
break |
|||
} |
|||
if err != nil { |
|||
return out, err |
|||
} |
|||
|
|||
switch typed := tok.(type) { |
|||
case xml.CharData: |
|||
out.Text = string(typed.Copy()) |
|||
case xml.StartElement: |
|||
el := typed.Copy() |
|||
out.Attr = el.Attr |
|||
if out.Children == nil { |
|||
out.Children = map[string][]*XMLNode{} |
|||
} |
|||
|
|||
name := typed.Name.Local |
|||
slice := out.Children[name] |
|||
if slice == nil { |
|||
slice = []*XMLNode{} |
|||
} |
|||
node, e := XMLToStruct(d, &el) |
|||
if e != nil { |
|||
return out, e |
|||
} |
|||
node.Name = typed.Name |
|||
slice = append(slice, node) |
|||
out.Children[name] = slice |
|||
case xml.EndElement: |
|||
if s != nil && s.Name.Local == typed.Name.Local { // matching end token
|
|||
return out, nil |
|||
} |
|||
} |
|||
} |
|||
return out, nil |
|||
} |
|||
|
|||
// StructToXML writes an XMLNode to a xml.Encoder as tokens.
|
|||
func StructToXML(e *xml.Encoder, node *XMLNode, sorted bool) error { |
|||
e.EncodeToken(xml.StartElement{Name: node.Name, Attr: node.Attr}) |
|||
|
|||
if node.Text != "" { |
|||
e.EncodeToken(xml.CharData([]byte(node.Text))) |
|||
} else if sorted { |
|||
sortedNames := []string{} |
|||
for k := range node.Children { |
|||
sortedNames = append(sortedNames, k) |
|||
} |
|||
sort.Strings(sortedNames) |
|||
|
|||
for _, k := range sortedNames { |
|||
for _, v := range node.Children[k] { |
|||
StructToXML(e, v, sorted) |
|||
} |
|||
} |
|||
} else { |
|||
for _, c := range node.Children { |
|||
for _, v := range c { |
|||
StructToXML(e, v, sorted) |
|||
} |
|||
} |
|||
} |
|||
|
|||
e.EncodeToken(xml.EndElement{Name: node.Name}) |
|||
return e.Flush() |
|||
} |
File diff suppressed because it is too large
@ -0,0 +1,12 @@ |
|||
package sts |
|||
|
|||
import "github.com/aws/aws-sdk-go/aws/request" |
|||
|
|||
func init() { |
|||
initRequest = func(r *request.Request) { |
|||
switch r.Operation.Name { |
|||
case opAssumeRoleWithSAML, opAssumeRoleWithWebIdentity: |
|||
r.Handlers.Sign.Clear() // these operations are unsigned
|
|||
} |
|||
} |
|||
} |
@ -0,0 +1,130 @@ |
|||
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
|
|||
|
|||
package sts |
|||
|
|||
import ( |
|||
"github.com/aws/aws-sdk-go/aws" |
|||
"github.com/aws/aws-sdk-go/aws/client" |
|||
"github.com/aws/aws-sdk-go/aws/client/metadata" |
|||
"github.com/aws/aws-sdk-go/aws/request" |
|||
"github.com/aws/aws-sdk-go/aws/signer/v4" |
|||
"github.com/aws/aws-sdk-go/private/protocol/query" |
|||
) |
|||
|
|||
// The AWS Security Token Service (STS) is a web service that enables you to
|
|||
// request temporary, limited-privilege credentials for AWS Identity and Access
|
|||
// Management (IAM) users or for users that you authenticate (federated users).
|
|||
// This guide provides descriptions of the STS API. For more detailed information
|
|||
// about using this service, go to Temporary Security Credentials (http://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp.html).
|
|||
//
|
|||
// As an alternative to using the API, you can use one of the AWS SDKs, which
|
|||
// consist of libraries and sample code for various programming languages and
|
|||
// platforms (Java, Ruby, .NET, iOS, Android, etc.). The SDKs provide a convenient
|
|||
// way to create programmatic access to STS. For example, the SDKs take care
|
|||
// of cryptographically signing requests, managing errors, and retrying requests
|
|||
// automatically. For information about the AWS SDKs, including how to download
|
|||
// and install them, see the Tools for Amazon Web Services page (http://aws.amazon.com/tools/).
|
|||
//
|
|||
// For information about setting up signatures and authorization through the
|
|||
// API, go to Signing AWS API Requests (http://docs.aws.amazon.com/general/latest/gr/signing_aws_api_requests.html)
|
|||
// in the AWS General Reference. For general information about the Query API,
|
|||
// go to Making Query Requests (http://docs.aws.amazon.com/IAM/latest/UserGuide/IAM_UsingQueryAPI.html)
|
|||
// in Using IAM. For information about using security tokens with other AWS
|
|||
// products, go to AWS Services That Work with IAM (http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_aws-services-that-work-with-iam.html)
|
|||
// in the IAM User Guide.
|
|||
//
|
|||
// If you're new to AWS and need additional technical information about a specific
|
|||
// AWS product, you can find the product's technical documentation at http://aws.amazon.com/documentation/
|
|||
// (http://aws.amazon.com/documentation/).
|
|||
//
|
|||
// Endpoints
|
|||
//
|
|||
// The AWS Security Token Service (STS) has a default endpoint of https://sts.amazonaws.com
|
|||
// that maps to the US East (N. Virginia) region. Additional regions are available
|
|||
// and are activated by default. For more information, see Activating and Deactivating
|
|||
// AWS STS in an AWS Region (http://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html)
|
|||
// in the IAM User Guide.
|
|||
//
|
|||
// For information about STS endpoints, see Regions and Endpoints (http://docs.aws.amazon.com/general/latest/gr/rande.html#sts_region)
|
|||
// in the AWS General Reference.
|
|||
//
|
|||
// Recording API requests
|
|||
//
|
|||
// STS supports AWS CloudTrail, which is a service that records AWS calls for
|
|||
// your AWS account and delivers log files to an Amazon S3 bucket. By using
|
|||
// information collected by CloudTrail, you can determine what requests were
|
|||
// successfully made to STS, who made the request, when it was made, and so
|
|||
// on. To learn more about CloudTrail, including how to turn it on and find
|
|||
// your log files, see the AWS CloudTrail User Guide (http://docs.aws.amazon.com/awscloudtrail/latest/userguide/what_is_cloud_trail_top_level.html).
|
|||
//The service client's operations are safe to be used concurrently.
|
|||
// It is not safe to mutate any of the client's properties though.
|
|||
type STS struct { |
|||
*client.Client |
|||
} |
|||
|
|||
// Used for custom client initialization logic
|
|||
var initClient func(*client.Client) |
|||
|
|||
// Used for custom request initialization logic
|
|||
var initRequest func(*request.Request) |
|||
|
|||
// A ServiceName is the name of the service the client will make API calls to.
|
|||
const ServiceName = "sts" |
|||
|
|||
// New creates a new instance of the STS client with a session.
|
|||
// If additional configuration is needed for the client instance use the optional
|
|||
// aws.Config parameter to add your extra config.
|
|||
//
|
|||
// Example:
|
|||
// // Create a STS client from just a session.
|
|||
// svc := sts.New(mySession)
|
|||
//
|
|||
// // Create a STS client with additional configuration
|
|||
// svc := sts.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
|
|||
func New(p client.ConfigProvider, cfgs ...*aws.Config) *STS { |
|||
c := p.ClientConfig(ServiceName, cfgs...) |
|||
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion) |
|||
} |
|||
|
|||
// newClient creates, initializes and returns a new service client instance.
|
|||
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string) *STS { |
|||
svc := &STS{ |
|||
Client: client.New( |
|||
cfg, |
|||
metadata.ClientInfo{ |
|||
ServiceName: ServiceName, |
|||
SigningRegion: signingRegion, |
|||
Endpoint: endpoint, |
|||
APIVersion: "2011-06-15", |
|||
}, |
|||
handlers, |
|||
), |
|||
} |
|||
|
|||
// Handlers
|
|||
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler) |
|||
svc.Handlers.Build.PushBackNamed(query.BuildHandler) |
|||
svc.Handlers.Unmarshal.PushBackNamed(query.UnmarshalHandler) |
|||
svc.Handlers.UnmarshalMeta.PushBackNamed(query.UnmarshalMetaHandler) |
|||
svc.Handlers.UnmarshalError.PushBackNamed(query.UnmarshalErrorHandler) |
|||
|
|||
// Run custom client initialization if present
|
|||
if initClient != nil { |
|||
initClient(svc.Client) |
|||
} |
|||
|
|||
return svc |
|||
} |
|||
|
|||
// newRequest creates a new request for a STS operation and runs any
|
|||
// custom request initialization.
|
|||
func (c *STS) newRequest(op *request.Operation, params, data interface{}) *request.Request { |
|||
req := c.NewRequest(op, params, data) |
|||
|
|||
// Run custom request initialization if present
|
|||
if initRequest != nil { |
|||
initRequest(req) |
|||
} |
|||
|
|||
return req |
|||
} |
@ -0,0 +1,191 @@ |
|||
Apache License |
|||
Version 2.0, January 2004 |
|||
http://www.apache.org/licenses/ |
|||
|
|||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION |
|||
|
|||
1. Definitions. |
|||
|
|||
"License" shall mean the terms and conditions for use, reproduction, and |
|||
distribution as defined by Sections 1 through 9 of this document. |
|||
|
|||
"Licensor" shall mean the copyright owner or entity authorized by the copyright |
|||
owner that is granting the License. |
|||
|
|||
"Legal Entity" shall mean the union of the acting entity and all other entities |
|||
that control, are controlled by, or are under common control with that entity. |
|||
For the purposes of this definition, "control" means (i) the power, direct or |
|||
indirect, to cause the direction or management of such entity, whether by |
|||
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the |
|||
outstanding shares, or (iii) beneficial ownership of such entity. |
|||
|
|||
"You" (or "Your") shall mean an individual or Legal Entity exercising |
|||
permissions granted by this License. |
|||
|
|||
"Source" form shall mean the preferred form for making modifications, including |
|||
but not limited to software source code, documentation source, and configuration |
|||
files. |
|||
|
|||
"Object" form shall mean any form resulting from mechanical transformation or |
|||
translation of a Source form, including but not limited to compiled object code, |
|||
generated documentation, and conversions to other media types. |
|||
|
|||
"Work" shall mean the work of authorship, whether in Source or Object form, made |
|||
available under the License, as indicated by a copyright notice that is included |
|||
in or attached to the work (an example is provided in the Appendix below). |
|||
|
|||
"Derivative Works" shall mean any work, whether in Source or Object form, that |
|||
is based on (or derived from) the Work and for which the editorial revisions, |
|||
annotations, elaborations, or other modifications represent, as a whole, an |
|||
original work of authorship. For the purposes of this License, Derivative Works |
|||
shall not include works that remain separable from, or merely link (or bind by |
|||
name) to the interfaces of, the Work and Derivative Works thereof. |
|||
|
|||
"Contribution" shall mean any work of authorship, including the original version |
|||
of the Work and any modifications or additions to that Work or Derivative Works |
|||
thereof, that is intentionally submitted to Licensor for inclusion in the Work |
|||
by the copyright owner or by an individual or Legal Entity authorized to submit |
|||
on behalf of the copyright owner. For the purposes of this definition, |
|||
"submitted" means any form of electronic, verbal, or written communication sent |
|||
to the Licensor or its representatives, including but not limited to |
|||
communication on electronic mailing lists, source code control systems, and |
|||
issue tracking systems that are managed by, or on behalf of, the Licensor for |
|||
the purpose of discussing and improving the Work, but excluding communication |
|||
that is conspicuously marked or otherwise designated in writing by the copyright |
|||
owner as "Not a Contribution." |
|||
|
|||
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf |
|||
of whom a Contribution has been received by Licensor and subsequently |
|||
incorporated within the Work. |
|||
|
|||
2. Grant of Copyright License. |
|||
|
|||
Subject to the terms and conditions of this License, each Contributor hereby |
|||
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, |
|||
irrevocable copyright license to reproduce, prepare Derivative Works of, |
|||
publicly display, publicly perform, sublicense, and distribute the Work and such |
|||
Derivative Works in Source or Object form. |
|||
|
|||
3. Grant of Patent License. |
|||
|
|||
Subject to the terms and conditions of this License, each Contributor hereby |
|||
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, |
|||
irrevocable (except as stated in this section) patent license to make, have |
|||
made, use, offer to sell, sell, import, and otherwise transfer the Work, where |
|||
such license applies only to those patent claims licensable by such Contributor |
|||
that are necessarily infringed by their Contribution(s) alone or by combination |
|||
of their Contribution(s) with the Work to which such Contribution(s) was |
|||
submitted. If You institute patent litigation against any entity (including a |
|||
cross-claim or counterclaim in a lawsuit) alleging that the Work or a |
|||
Contribution incorporated within the Work constitutes direct or contributory |
|||
patent infringement, then any patent licenses granted to You under this License |
|||
for that Work shall terminate as of the date such litigation is filed. |
|||
|
|||
4. Redistribution. |
|||
|
|||
You may reproduce and distribute copies of the Work or Derivative Works thereof |
|||
in any medium, with or without modifications, and in Source or Object form, |
|||
provided that You meet the following conditions: |
|||
|
|||
You must give any other recipients of the Work or Derivative Works a copy of |
|||
this License; and |
|||
You must cause any modified files to carry prominent notices stating that You |
|||
changed the files; and |
|||
You must retain, in the Source form of any Derivative Works that You distribute, |
|||
all copyright, patent, trademark, and attribution notices from the Source form |
|||
of the Work, excluding those notices that do not pertain to any part of the |
|||
Derivative Works; and |
|||
If the Work includes a "NOTICE" text file as part of its distribution, then any |
|||
Derivative Works that You distribute must include a readable copy of the |
|||
attribution notices contained within such NOTICE file, excluding those notices |
|||
that do not pertain to any part of the Derivative Works, in at least one of the |
|||
following places: within a NOTICE text file distributed as part of the |
|||
Derivative Works; within the Source form or documentation, if provided along |
|||
with the Derivative Works; or, within a display generated by the Derivative |
|||
Works, if and wherever such third-party notices normally appear. The contents of |
|||
the NOTICE file are for informational purposes only and do not modify the |
|||
License. You may add Your own attribution notices within Derivative Works that |
|||
You distribute, alongside or as an addendum to the NOTICE text from the Work, |
|||
provided that such additional attribution notices cannot be construed as |
|||
modifying the License. |
|||
You may add Your own copyright statement to Your modifications and may provide |
|||
additional or different license terms and conditions for use, reproduction, or |
|||
distribution of Your modifications, or for any such Derivative Works as a whole, |
|||
provided Your use, reproduction, and distribution of the Work otherwise complies |
|||
with the conditions stated in this License. |
|||
|
|||
5. Submission of Contributions. |
|||
|
|||
Unless You explicitly state otherwise, any Contribution intentionally submitted |
|||
for inclusion in the Work by You to the Licensor shall be under the terms and |
|||
conditions of this License, without any additional terms or conditions. |
|||
Notwithstanding the above, nothing herein shall supersede or modify the terms of |
|||
any separate license agreement you may have executed with Licensor regarding |
|||
such Contributions. |
|||
|
|||
6. Trademarks. |
|||
|
|||
This License does not grant permission to use the trade names, trademarks, |
|||
service marks, or product names of the Licensor, except as required for |
|||
reasonable and customary use in describing the origin of the Work and |
|||
reproducing the content of the NOTICE file. |
|||
|
|||
7. Disclaimer of Warranty. |
|||
|
|||
Unless required by applicable law or agreed to in writing, Licensor provides the |
|||
Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, |
|||
including, without limitation, any warranties or conditions of TITLE, |
|||
NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are |
|||
solely responsible for determining the appropriateness of using or |
|||
redistributing the Work and assume any risks associated with Your exercise of |
|||
permissions under this License. |
|||
|
|||
8. Limitation of Liability. |
|||
|
|||
In no event and under no legal theory, whether in tort (including negligence), |
|||
contract, or otherwise, unless required by applicable law (such as deliberate |
|||
and grossly negligent acts) or agreed to in writing, shall any Contributor be |
|||
liable to You for damages, including any direct, indirect, special, incidental, |
|||
or consequential damages of any character arising as a result of this License or |
|||
out of the use or inability to use the Work (including but not limited to |
|||
damages for loss of goodwill, work stoppage, computer failure or malfunction, or |
|||
any and all other commercial damages or losses), even if such Contributor has |
|||
been advised of the possibility of such damages. |
|||
|
|||
9. Accepting Warranty or Additional Liability. |
|||
|
|||
While redistributing the Work or Derivative Works thereof, You may choose to |
|||
offer, and charge a fee for, acceptance of support, warranty, indemnity, or |
|||
other liability obligations and/or rights consistent with this License. However, |
|||
in accepting such obligations, You may act only on Your own behalf and on Your |
|||
sole responsibility, not on behalf of any other Contributor, and only if You |
|||
agree to indemnify, defend, and hold each Contributor harmless for any liability |
|||
incurred by, or claims asserted against, such Contributor by reason of your |
|||
accepting any such warranty or additional liability. |
|||
|
|||
END OF TERMS AND CONDITIONS |
|||
|
|||
APPENDIX: How to apply the Apache License to your work |
|||
|
|||
To apply the Apache License to your work, attach the following boilerplate |
|||
notice, with the fields enclosed by brackets "[]" replaced with your own |
|||
identifying information. (Don't include the brackets!) The text should be |
|||
enclosed in the appropriate comment syntax for the file format. We also |
|||
recommend that a file or class name and description of purpose be included on |
|||
the same "printed page" as the copyright notice for easier identification within |
|||
third-party archives. |
|||
|
|||
Copyright [yyyy] [name of copyright owner] |
|||
|
|||
Licensed under the Apache License, Version 2.0 (the "License"); |
|||
you may not use this file except in compliance with the License. |
|||
You may obtain a copy of the License at |
|||
|
|||
http://www.apache.org/licenses/LICENSE-2.0 |
|||
|
|||
Unless required by applicable law or agreed to in writing, software |
|||
distributed under the License is distributed on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
See the License for the specific language governing permissions and |
|||
limitations under the License. |
@ -0,0 +1,12 @@ |
|||
.PHONY: build test bench vet |
|||
|
|||
build: vet bench |
|||
|
|||
test: |
|||
go test -v -cover -race |
|||
|
|||
bench: |
|||
go test -v -cover -race -test.bench=. -test.benchmem |
|||
|
|||
vet: |
|||
go vet |
@ -0,0 +1,703 @@ |
|||
INI [![Build Status](https://travis-ci.org/go-ini/ini.svg?branch=master)](https://travis-ci.org/go-ini/ini) |
|||
=== |
|||
|
|||
![](https://avatars0.githubusercontent.com/u/10216035?v=3&s=200) |
|||
|
|||
Package ini provides INI file read and write functionality in Go. |
|||
|
|||
[简体中文](README_ZH.md) |
|||
|
|||
## Feature |
|||
|
|||
- Load multiple data sources(`[]byte` or file) with overwrites. |
|||
- Read with recursion values. |
|||
- Read with parent-child sections. |
|||
- Read with auto-increment key names. |
|||
- Read with multiple-line values. |
|||
- Read with tons of helper methods. |
|||
- Read and convert values to Go types. |
|||
- Read and **WRITE** comments of sections and keys. |
|||
- Manipulate sections, keys and comments with ease. |
|||
- Keep sections and keys in order as you parse and save. |
|||
|
|||
## Installation |
|||
|
|||
To use a tagged revision: |
|||
|
|||
go get gopkg.in/ini.v1 |
|||
|
|||
To use with latest changes: |
|||
|
|||
go get github.com/go-ini/ini |
|||
|
|||
Please add `-u` flag to update in the future. |
|||
|
|||
### Testing |
|||
|
|||
If you want to test on your machine, please apply `-t` flag: |
|||
|
|||
go get -t gopkg.in/ini.v1 |
|||
|
|||
Please add `-u` flag to update in the future. |
|||
|
|||
## Getting Started |
|||
|
|||
### Loading from data sources |
|||
|
|||
A **Data Source** is either raw data in type `[]byte` or a file name with type `string` and you can load **as many data sources as you want**. Passing other types will simply return an error. |
|||
|
|||
```go |
|||
cfg, err := ini.Load([]byte("raw data"), "filename") |
|||
``` |
|||
|
|||
Or start with an empty object: |
|||
|
|||
```go |
|||
cfg := ini.Empty() |
|||
``` |
|||
|
|||
When you cannot decide how many data sources to load at the beginning, you will still be able to **Append()** them later. |
|||
|
|||
```go |
|||
err := cfg.Append("other file", []byte("other raw data")) |
|||
``` |
|||
|
|||
If you have a list of files with possibilities that some of them may not available at the time, and you don't know exactly which ones, you can use `LooseLoad` to ignore nonexistent files without returning error. |
|||
|
|||
```go |
|||
cfg, err := ini.LooseLoad("filename", "filename_404") |
|||
``` |
|||
|
|||
The cool thing is, whenever the file is available to load while you're calling `Reload` method, it will be counted as usual. |
|||
|
|||
#### Ignore cases of key name |
|||
|
|||
When you do not care about cases of section and key names, you can use `InsensitiveLoad` to force all names to be lowercased while parsing. |
|||
|
|||
```go |
|||
cfg, err := ini.InsensitiveLoad("filename") |
|||
//... |
|||
|
|||
// sec1 and sec2 are the exactly same section object |
|||
sec1, err := cfg.GetSection("Section") |
|||
sec2, err := cfg.GetSection("SecTIOn") |
|||
|
|||
// key1 and key2 are the exactly same key object |
|||
key1, err := cfg.GetKey("Key") |
|||
key2, err := cfg.GetKey("KeY") |
|||
``` |
|||
|
|||
#### MySQL-like boolean key |
|||
|
|||
MySQL's configuration allows a key without value as follows: |
|||
|
|||
```ini |
|||
[mysqld] |
|||
... |
|||
skip-host-cache |
|||
skip-name-resolve |
|||
``` |
|||
|
|||
By default, this is considered as missing value. But if you know you're going to deal with those cases, you can assign advanced load options: |
|||
|
|||
```go |
|||
cfg, err := LoadSources(LoadOptions{AllowBooleanKeys: true}, "my.cnf")) |
|||
``` |
|||
|
|||
The value of those keys are always `true`, and when you save to a file, it will keep in the same foramt as you read. |
|||
|
|||
### Working with sections |
|||
|
|||
To get a section, you would need to: |
|||
|
|||
```go |
|||
section, err := cfg.GetSection("section name") |
|||
``` |
|||
|
|||
For a shortcut for default section, just give an empty string as name: |
|||
|
|||
```go |
|||
section, err := cfg.GetSection("") |
|||
``` |
|||
|
|||
When you're pretty sure the section exists, following code could make your life easier: |
|||
|
|||
```go |
|||
section := cfg.Section("") |
|||
``` |
|||
|
|||
What happens when the section somehow does not exist? Don't panic, it automatically creates and returns a new section to you. |
|||
|
|||
To create a new section: |
|||
|
|||
```go |
|||
err := cfg.NewSection("new section") |
|||
``` |
|||
|
|||
To get a list of sections or section names: |
|||
|
|||
```go |
|||
sections := cfg.Sections() |
|||
names := cfg.SectionStrings() |
|||
``` |
|||
|
|||
### Working with keys |
|||
|
|||
To get a key under a section: |
|||
|
|||
```go |
|||
key, err := cfg.Section("").GetKey("key name") |
|||
``` |
|||
|
|||
Same rule applies to key operations: |
|||
|
|||
```go |
|||
key := cfg.Section("").Key("key name") |
|||
``` |
|||
|
|||
To check if a key exists: |
|||
|
|||
```go |
|||
yes := cfg.Section("").HasKey("key name") |
|||
``` |
|||
|
|||
To create a new key: |
|||
|
|||
```go |
|||
err := cfg.Section("").NewKey("name", "value") |
|||
``` |
|||
|
|||
To get a list of keys or key names: |
|||
|
|||
```go |
|||
keys := cfg.Section("").Keys() |
|||
names := cfg.Section("").KeyStrings() |
|||
``` |
|||
|
|||
To get a clone hash of keys and corresponding values: |
|||
|
|||
```go |
|||
hash := cfg.Section("").KeysHash() |
|||
``` |
|||
|
|||
### Working with values |
|||
|
|||
To get a string value: |
|||
|
|||
```go |
|||
val := cfg.Section("").Key("key name").String() |
|||
``` |
|||
|
|||
To validate key value on the fly: |
|||
|
|||
```go |
|||
val := cfg.Section("").Key("key name").Validate(func(in string) string { |
|||
if len(in) == 0 { |
|||
return "default" |
|||
} |
|||
return in |
|||
}) |
|||
``` |
|||
|
|||
If you do not want any auto-transformation (such as recursive read) for the values, you can get raw value directly (this way you get much better performance): |
|||
|
|||
```go |
|||
val := cfg.Section("").Key("key name").Value() |
|||
``` |
|||
|
|||
To check if raw value exists: |
|||
|
|||
```go |
|||
yes := cfg.Section("").HasValue("test value") |
|||
``` |
|||
|
|||
To get value with types: |
|||
|
|||
```go |
|||
// For boolean values: |
|||
// true when value is: 1, t, T, TRUE, true, True, YES, yes, Yes, y, ON, on, On |
|||
// false when value is: 0, f, F, FALSE, false, False, NO, no, No, n, OFF, off, Off |
|||
v, err = cfg.Section("").Key("BOOL").Bool() |
|||
v, err = cfg.Section("").Key("FLOAT64").Float64() |
|||
v, err = cfg.Section("").Key("INT").Int() |
|||
v, err = cfg.Section("").Key("INT64").Int64() |
|||
v, err = cfg.Section("").Key("UINT").Uint() |
|||
v, err = cfg.Section("").Key("UINT64").Uint64() |
|||
v, err = cfg.Section("").Key("TIME").TimeFormat(time.RFC3339) |
|||
v, err = cfg.Section("").Key("TIME").Time() // RFC3339 |
|||
|
|||
v = cfg.Section("").Key("BOOL").MustBool() |
|||
v = cfg.Section("").Key("FLOAT64").MustFloat64() |
|||
v = cfg.Section("").Key("INT").MustInt() |
|||
v = cfg.Section("").Key("INT64").MustInt64() |
|||
v = cfg.Section("").Key("UINT").MustUint() |
|||
v = cfg.Section("").Key("UINT64").MustUint64() |
|||
v = cfg.Section("").Key("TIME").MustTimeFormat(time.RFC3339) |
|||
v = cfg.Section("").Key("TIME").MustTime() // RFC3339 |
|||
|
|||
// Methods start with Must also accept one argument for default value |
|||
// when key not found or fail to parse value to given type. |
|||
// Except method MustString, which you have to pass a default value. |
|||
|
|||
v = cfg.Section("").Key("String").MustString("default") |
|||
v = cfg.Section("").Key("BOOL").MustBool(true) |
|||
v = cfg.Section("").Key("FLOAT64").MustFloat64(1.25) |
|||
v = cfg.Section("").Key("INT").MustInt(10) |
|||
v = cfg.Section("").Key("INT64").MustInt64(99) |
|||
v = cfg.Section("").Key("UINT").MustUint(3) |
|||
v = cfg.Section("").Key("UINT64").MustUint64(6) |
|||
v = cfg.Section("").Key("TIME").MustTimeFormat(time.RFC3339, time.Now()) |
|||
v = cfg.Section("").Key("TIME").MustTime(time.Now()) // RFC3339 |
|||
``` |
|||
|
|||
What if my value is three-line long? |
|||
|
|||
```ini |
|||
[advance] |
|||
ADDRESS = """404 road, |
|||
NotFound, State, 5000 |
|||
Earth""" |
|||
``` |
|||
|
|||
Not a problem! |
|||
|
|||
```go |
|||
cfg.Section("advance").Key("ADDRESS").String() |
|||
|
|||
/* --- start --- |
|||
404 road, |
|||
NotFound, State, 5000 |
|||
Earth |
|||
------ end --- */ |
|||
``` |
|||
|
|||
That's cool, how about continuation lines? |
|||
|
|||
```ini |
|||
[advance] |
|||
two_lines = how about \ |
|||
continuation lines? |
|||
lots_of_lines = 1 \ |
|||
2 \ |
|||
3 \ |
|||
4 |
|||
``` |
|||
|
|||
Piece of cake! |
|||
|
|||
```go |
|||
cfg.Section("advance").Key("two_lines").String() // how about continuation lines? |
|||
cfg.Section("advance").Key("lots_of_lines").String() // 1 2 3 4 |
|||
``` |
|||
|
|||
Well, I hate continuation lines, how do I disable that? |
|||
|
|||
```go |
|||
cfg, err := ini.LoadSources(ini.LoadOptions{ |
|||
IgnoreContinuation: true, |
|||
}, "filename") |
|||
``` |
|||
|
|||
Holy crap! |
|||
|
|||
Note that single quotes around values will be stripped: |
|||
|
|||
```ini |
|||
foo = "some value" // foo: some value |
|||
bar = 'some value' // bar: some value |
|||
``` |
|||
|
|||
That's all? Hmm, no. |
|||
|
|||
#### Helper methods of working with values |
|||
|
|||
To get value with given candidates: |
|||
|
|||
```go |
|||
v = cfg.Section("").Key("STRING").In("default", []string{"str", "arr", "types"}) |
|||
v = cfg.Section("").Key("FLOAT64").InFloat64(1.1, []float64{1.25, 2.5, 3.75}) |
|||
v = cfg.Section("").Key("INT").InInt(5, []int{10, 20, 30}) |
|||
v = cfg.Section("").Key("INT64").InInt64(10, []int64{10, 20, 30}) |
|||
v = cfg.Section("").Key("UINT").InUint(4, []int{3, 6, 9}) |
|||
v = cfg.Section("").Key("UINT64").InUint64(8, []int64{3, 6, 9}) |
|||
v = cfg.Section("").Key("TIME").InTimeFormat(time.RFC3339, time.Now(), []time.Time{time1, time2, time3}) |
|||
v = cfg.Section("").Key("TIME").InTime(time.Now(), []time.Time{time1, time2, time3}) // RFC3339 |
|||
``` |
|||
|
|||
Default value will be presented if value of key is not in candidates you given, and default value does not need be one of candidates. |
|||
|
|||
To validate value in a given range: |
|||
|
|||
```go |
|||
vals = cfg.Section("").Key("FLOAT64").RangeFloat64(0.0, 1.1, 2.2) |
|||
vals = cfg.Section("").Key("INT").RangeInt(0, 10, 20) |
|||
vals = cfg.Section("").Key("INT64").RangeInt64(0, 10, 20) |
|||
vals = cfg.Section("").Key("UINT").RangeUint(0, 3, 9) |
|||
vals = cfg.Section("").Key("UINT64").RangeUint64(0, 3, 9) |
|||
vals = cfg.Section("").Key("TIME").RangeTimeFormat(time.RFC3339, time.Now(), minTime, maxTime) |
|||
vals = cfg.Section("").Key("TIME").RangeTime(time.Now(), minTime, maxTime) // RFC3339 |
|||
``` |
|||
|
|||
##### Auto-split values into a slice |
|||
|
|||
To use zero value of type for invalid inputs: |
|||
|
|||
```go |
|||
// Input: 1.1, 2.2, 3.3, 4.4 -> [1.1 2.2 3.3 4.4] |
|||
// Input: how, 2.2, are, you -> [0.0 2.2 0.0 0.0] |
|||
vals = cfg.Section("").Key("STRINGS").Strings(",") |
|||
vals = cfg.Section("").Key("FLOAT64S").Float64s(",") |
|||
vals = cfg.Section("").Key("INTS").Ints(",") |
|||
vals = cfg.Section("").Key("INT64S").Int64s(",") |
|||
vals = cfg.Section("").Key("UINTS").Uints(",") |
|||
vals = cfg.Section("").Key("UINT64S").Uint64s(",") |
|||
vals = cfg.Section("").Key("TIMES").Times(",") |
|||
``` |
|||
|
|||
To exclude invalid values out of result slice: |
|||
|
|||
```go |
|||
// Input: 1.1, 2.2, 3.3, 4.4 -> [1.1 2.2 3.3 4.4] |
|||
// Input: how, 2.2, are, you -> [2.2] |
|||
vals = cfg.Section("").Key("FLOAT64S").ValidFloat64s(",") |
|||
vals = cfg.Section("").Key("INTS").ValidInts(",") |
|||
vals = cfg.Section("").Key("INT64S").ValidInt64s(",") |
|||
vals = cfg.Section("").Key("UINTS").ValidUints(",") |
|||
vals = cfg.Section("").Key("UINT64S").ValidUint64s(",") |
|||
vals = cfg.Section("").Key("TIMES").ValidTimes(",") |
|||
``` |
|||
|
|||
Or to return nothing but error when have invalid inputs: |
|||
|
|||
```go |
|||
// Input: 1.1, 2.2, 3.3, 4.4 -> [1.1 2.2 3.3 4.4] |
|||
// Input: how, 2.2, are, you -> error |
|||
vals = cfg.Section("").Key("FLOAT64S").StrictFloat64s(",") |
|||
vals = cfg.Section("").Key("INTS").StrictInts(",") |
|||
vals = cfg.Section("").Key("INT64S").StrictInt64s(",") |
|||
vals = cfg.Section("").Key("UINTS").StrictUints(",") |
|||
vals = cfg.Section("").Key("UINT64S").StrictUint64s(",") |
|||
vals = cfg.Section("").Key("TIMES").StrictTimes(",") |
|||
``` |
|||
|
|||
### Save your configuration |
|||
|
|||
Finally, it's time to save your configuration to somewhere. |
|||
|
|||
A typical way to save configuration is writing it to a file: |
|||
|
|||
```go |
|||
// ... |
|||
err = cfg.SaveTo("my.ini") |
|||
err = cfg.SaveToIndent("my.ini", "\t") |
|||
``` |
|||
|
|||
Another way to save is writing to a `io.Writer` interface: |
|||
|
|||
```go |
|||
// ... |
|||
cfg.WriteTo(writer) |
|||
cfg.WriteToIndent(writer, "\t") |
|||
``` |
|||
|
|||
## Advanced Usage |
|||
|
|||
### Recursive Values |
|||
|
|||
For all value of keys, there is a special syntax `%(<name>)s`, where `<name>` is the key name in same section or default section, and `%(<name>)s` will be replaced by corresponding value(empty string if key not found). You can use this syntax at most 99 level of recursions. |
|||
|
|||
```ini |
|||
NAME = ini |
|||
|
|||
[author] |
|||
NAME = Unknwon |
|||
GITHUB = https://github.com/%(NAME)s |
|||
|
|||
[package] |
|||
FULL_NAME = github.com/go-ini/%(NAME)s |
|||
``` |
|||
|
|||
```go |
|||
cfg.Section("author").Key("GITHUB").String() // https://github.com/Unknwon |
|||
cfg.Section("package").Key("FULL_NAME").String() // github.com/go-ini/ini |
|||
``` |
|||
|
|||
### Parent-child Sections |
|||
|
|||
You can use `.` in section name to indicate parent-child relationship between two or more sections. If the key not found in the child section, library will try again on its parent section until there is no parent section. |
|||
|
|||
```ini |
|||
NAME = ini |
|||
VERSION = v1 |
|||
IMPORT_PATH = gopkg.in/%(NAME)s.%(VERSION)s |
|||
|
|||
[package] |
|||
CLONE_URL = https://%(IMPORT_PATH)s |
|||
|
|||
[package.sub] |
|||
``` |
|||
|
|||
```go |
|||
cfg.Section("package.sub").Key("CLONE_URL").String() // https://gopkg.in/ini.v1 |
|||
``` |
|||
|
|||
#### Retrieve parent keys available to a child section |
|||
|
|||
```go |
|||
cfg.Section("package.sub").ParentKeys() // ["CLONE_URL"] |
|||
``` |
|||
|
|||
### Auto-increment Key Names |
|||
|
|||
If key name is `-` in data source, then it would be seen as special syntax for auto-increment key name start from 1, and every section is independent on counter. |
|||
|
|||
```ini |
|||
[features] |
|||
-: Support read/write comments of keys and sections |
|||
-: Support auto-increment of key names |
|||
-: Support load multiple files to overwrite key values |
|||
``` |
|||
|
|||
```go |
|||
cfg.Section("features").KeyStrings() // []{"#1", "#2", "#3"} |
|||
``` |
|||
|
|||
### Map To Struct |
|||
|
|||
Want more objective way to play with INI? Cool. |
|||
|
|||
```ini |
|||
Name = Unknwon |
|||
age = 21 |
|||
Male = true |
|||
Born = 1993-01-01T20:17:05Z |
|||
|
|||
[Note] |
|||
Content = Hi is a good man! |
|||
Cities = HangZhou, Boston |
|||
``` |
|||
|
|||
```go |
|||
type Note struct { |
|||
Content string |
|||
Cities []string |
|||
} |
|||
|
|||
type Person struct { |
|||
Name string |
|||
Age int `ini:"age"` |
|||
Male bool |
|||
Born time.Time |
|||
Note |
|||
Created time.Time `ini:"-"` |
|||
} |
|||
|
|||
func main() { |
|||
cfg, err := ini.Load("path/to/ini") |
|||
// ... |
|||
p := new(Person) |
|||
err = cfg.MapTo(p) |
|||
// ... |
|||
|
|||
// Things can be simpler. |
|||
err = ini.MapTo(p, "path/to/ini") |
|||
// ... |
|||
|
|||
// Just map a section? Fine. |
|||
n := new(Note) |
|||
err = cfg.Section("Note").MapTo(n) |
|||
// ... |
|||
} |
|||
``` |
|||
|
|||
Can I have default value for field? Absolutely. |
|||
|
|||
Assign it before you map to struct. It will keep the value as it is if the key is not presented or got wrong type. |
|||
|
|||
```go |
|||
// ... |
|||
p := &Person{ |
|||
Name: "Joe", |
|||
} |
|||
// ... |
|||
``` |
|||
|
|||
It's really cool, but what's the point if you can't give me my file back from struct? |
|||
|
|||
### Reflect From Struct |
|||
|
|||
Why not? |
|||
|
|||
```go |
|||
type Embeded struct { |
|||
Dates []time.Time `delim:"|"` |
|||
Places []string `ini:"places,omitempty"` |
|||
None []int `ini:",omitempty"` |
|||
} |
|||
|
|||
type Author struct { |
|||
Name string `ini:"NAME"` |
|||
Male bool |
|||
Age int |
|||
GPA float64 |
|||
NeverMind string `ini:"-"` |
|||
*Embeded |
|||
} |
|||
|
|||
func main() { |
|||
a := &Author{"Unknwon", true, 21, 2.8, "", |
|||
&Embeded{ |
|||
[]time.Time{time.Now(), time.Now()}, |
|||
[]string{"HangZhou", "Boston"}, |
|||
[]int{}, |
|||
}} |
|||
cfg := ini.Empty() |
|||
err = ini.ReflectFrom(cfg, a) |
|||
// ... |
|||
} |
|||
``` |
|||
|
|||
So, what do I get? |
|||
|
|||
```ini |
|||
NAME = Unknwon |
|||
Male = true |
|||
Age = 21 |
|||
GPA = 2.8 |
|||
|
|||
[Embeded] |
|||
Dates = 2015-08-07T22:14:22+08:00|2015-08-07T22:14:22+08:00 |
|||
places = HangZhou,Boston |
|||
``` |
|||
|
|||
#### Name Mapper |
|||
|
|||
To save your time and make your code cleaner, this library supports [`NameMapper`](https://gowalker.org/gopkg.in/ini.v1#NameMapper) between struct field and actual section and key name. |
|||
|
|||
There are 2 built-in name mappers: |
|||
|
|||
- `AllCapsUnderscore`: it converts to format `ALL_CAPS_UNDERSCORE` then match section or key. |
|||
- `TitleUnderscore`: it converts to format `title_underscore` then match section or key. |
|||
|
|||
To use them: |
|||
|
|||
```go |
|||
type Info struct { |
|||
PackageName string |
|||
} |
|||
|
|||
func main() { |
|||
err = ini.MapToWithMapper(&Info{}, ini.TitleUnderscore, []byte("package_name=ini")) |
|||
// ... |
|||
|
|||
cfg, err := ini.Load([]byte("PACKAGE_NAME=ini")) |
|||
// ... |
|||
info := new(Info) |
|||
cfg.NameMapper = ini.AllCapsUnderscore |
|||
err = cfg.MapTo(info) |
|||
// ... |
|||
} |
|||
``` |
|||
|
|||
Same rules of name mapper apply to `ini.ReflectFromWithMapper` function. |
|||
|
|||
#### Value Mapper |
|||
|
|||
To expand values (e.g. from environment variables), you can use the `ValueMapper` to transform values: |
|||
|
|||
```go |
|||
type Env struct { |
|||
Foo string `ini:"foo"` |
|||
} |
|||
|
|||
func main() { |
|||
cfg, err := ini.Load([]byte("[env]\nfoo = ${MY_VAR}\n") |
|||
cfg.ValueMapper = os.ExpandEnv |
|||
// ... |
|||
env := &Env{} |
|||
err = cfg.Section("env").MapTo(env) |
|||
} |
|||
``` |
|||
|
|||
This would set the value of `env.Foo` to the value of the environment variable `MY_VAR`. |
|||
|
|||
#### Other Notes On Map/Reflect |
|||
|
|||
Any embedded struct is treated as a section by default, and there is no automatic parent-child relations in map/reflect feature: |
|||
|
|||
```go |
|||
type Child struct { |
|||
Age string |
|||
} |
|||
|
|||
type Parent struct { |
|||
Name string |
|||
Child |
|||
} |
|||
|
|||
type Config struct { |
|||
City string |
|||
Parent |
|||
} |
|||
``` |
|||
|
|||
Example configuration: |
|||
|
|||
```ini |
|||
City = Boston |
|||
|
|||
[Parent] |
|||
Name = Unknwon |
|||
|
|||
[Child] |
|||
Age = 21 |
|||
``` |
|||
|
|||
What if, yes, I'm paranoid, I want embedded struct to be in the same section. Well, all roads lead to Rome. |
|||
|
|||
```go |
|||
type Child struct { |
|||
Age string |
|||
} |
|||
|
|||
type Parent struct { |
|||
Name string |
|||
Child `ini:"Parent"` |
|||
} |
|||
|
|||
type Config struct { |
|||
City string |
|||
Parent |
|||
} |
|||
``` |
|||
|
|||
Example configuration: |
|||
|
|||
```ini |
|||
City = Boston |
|||
|
|||
[Parent] |
|||
Name = Unknwon |
|||
Age = 21 |
|||
``` |
|||
|
|||
## Getting Help |
|||
|
|||
- [API Documentation](https://gowalker.org/gopkg.in/ini.v1) |
|||
- [File An Issue](https://github.com/go-ini/ini/issues/new) |
|||
|
|||
## FAQs |
|||
|
|||
### What does `BlockMode` field do? |
|||
|
|||
By default, library lets you read and write values so we need a locker to make sure your data is safe. But in cases that you are very sure about only reading data through the library, you can set `cfg.BlockMode = false` to speed up read operations about **50-70%** faster. |
|||
|
|||
### Why another INI library? |
|||
|
|||
Many people are using my another INI library [goconfig](https://github.com/Unknwon/goconfig), so the reason for this one is I would like to make more Go style code. Also when you set `cfg.BlockMode = false`, this one is about **10-30%** faster. |
|||
|
|||
To make those changes I have to confirm API broken, so it's safer to keep it in another place and start using `gopkg.in` to version my package at this time.(PS: shorter import path) |
|||
|
|||
## License |
|||
|
|||
This project is under Apache v2 License. See the [LICENSE](LICENSE) file for the full license text. |
@ -0,0 +1,690 @@ |
|||
本包提供了 Go 语言中读写 INI 文件的功能。 |
|||
|
|||
## 功能特性 |
|||
|
|||
- 支持覆盖加载多个数据源(`[]byte` 或文件) |
|||
- 支持递归读取键值 |
|||
- 支持读取父子分区 |
|||
- 支持读取自增键名 |
|||
- 支持读取多行的键值 |
|||
- 支持大量辅助方法 |
|||
- 支持在读取时直接转换为 Go 语言类型 |
|||
- 支持读取和 **写入** 分区和键的注释 |
|||
- 轻松操作分区、键值和注释 |
|||
- 在保存文件时分区和键值会保持原有的顺序 |
|||
|
|||
## 下载安装 |
|||
|
|||
使用一个特定版本: |
|||
|
|||
go get gopkg.in/ini.v1 |
|||
|
|||
使用最新版: |
|||
|
|||
go get github.com/go-ini/ini |
|||
|
|||
如需更新请添加 `-u` 选项。 |
|||
|
|||
### 测试安装 |
|||
|
|||
如果您想要在自己的机器上运行测试,请使用 `-t` 标记: |
|||
|
|||
go get -t gopkg.in/ini.v1 |
|||
|
|||
如需更新请添加 `-u` 选项。 |
|||
|
|||
## 开始使用 |
|||
|
|||
### 从数据源加载 |
|||
|
|||
一个 **数据源** 可以是 `[]byte` 类型的原始数据,或 `string` 类型的文件路径。您可以加载 **任意多个** 数据源。如果您传递其它类型的数据源,则会直接返回错误。 |
|||
|
|||
```go |
|||
cfg, err := ini.Load([]byte("raw data"), "filename") |
|||
``` |
|||
|
|||
或者从一个空白的文件开始: |
|||
|
|||
```go |
|||
cfg := ini.Empty() |
|||
``` |
|||
|
|||
当您在一开始无法决定需要加载哪些数据源时,仍可以使用 **Append()** 在需要的时候加载它们。 |
|||
|
|||
```go |
|||
err := cfg.Append("other file", []byte("other raw data")) |
|||
``` |
|||
|
|||
当您想要加载一系列文件,但是不能够确定其中哪些文件是不存在的,可以通过调用函数 `LooseLoad` 来忽略它们(`Load` 会因为文件不存在而返回错误): |
|||
|
|||
```go |
|||
cfg, err := ini.LooseLoad("filename", "filename_404") |
|||
``` |
|||
|
|||
更牛逼的是,当那些之前不存在的文件在重新调用 `Reload` 方法的时候突然出现了,那么它们会被正常加载。 |
|||
|
|||
#### 忽略键名的大小写 |
|||
|
|||
有时候分区和键的名称大小写混合非常烦人,这个时候就可以通过 `InsensitiveLoad` 将所有分区和键名在读取里强制转换为小写: |
|||
|
|||
```go |
|||
cfg, err := ini.InsensitiveLoad("filename") |
|||
//... |
|||
|
|||
// sec1 和 sec2 指向同一个分区对象 |
|||
sec1, err := cfg.GetSection("Section") |
|||
sec2, err := cfg.GetSection("SecTIOn") |
|||
|
|||
// key1 和 key2 指向同一个键对象 |
|||
key1, err := cfg.GetKey("Key") |
|||
key2, err := cfg.GetKey("KeY") |
|||
``` |
|||
|
|||
#### 类似 MySQL 配置中的布尔值键 |
|||
|
|||
MySQL 的配置文件中会出现没有具体值的布尔类型的键: |
|||
|
|||
```ini |
|||
[mysqld] |
|||
... |
|||
skip-host-cache |
|||
skip-name-resolve |
|||
``` |
|||
|
|||
默认情况下这被认为是缺失值而无法完成解析,但可以通过高级的加载选项对它们进行处理: |
|||
|
|||
```go |
|||
cfg, err := LoadSources(LoadOptions{AllowBooleanKeys: true}, "my.cnf")) |
|||
``` |
|||
|
|||
这些键的值永远为 `true`,且在保存到文件时也只会输出键名。 |
|||
|
|||
### 操作分区(Section) |
|||
|
|||
获取指定分区: |
|||
|
|||
```go |
|||
section, err := cfg.GetSection("section name") |
|||
``` |
|||
|
|||
如果您想要获取默认分区,则可以用空字符串代替分区名: |
|||
|
|||
```go |
|||
section, err := cfg.GetSection("") |
|||
``` |
|||
|
|||
当您非常确定某个分区是存在的,可以使用以下简便方法: |
|||
|
|||
```go |
|||
section := cfg.Section("") |
|||
``` |
|||
|
|||
如果不小心判断错了,要获取的分区其实是不存在的,那会发生什么呢?没事的,它会自动创建并返回一个对应的分区对象给您。 |
|||
|
|||
创建一个分区: |
|||
|
|||
```go |
|||
err := cfg.NewSection("new section") |
|||
``` |
|||
|
|||
获取所有分区对象或名称: |
|||
|
|||
```go |
|||
sections := cfg.Sections() |
|||
names := cfg.SectionStrings() |
|||
``` |
|||
|
|||
### 操作键(Key) |
|||
|
|||
获取某个分区下的键: |
|||
|
|||
```go |
|||
key, err := cfg.Section("").GetKey("key name") |
|||
``` |
|||
|
|||
和分区一样,您也可以直接获取键而忽略错误处理: |
|||
|
|||
```go |
|||
key := cfg.Section("").Key("key name") |
|||
``` |
|||
|
|||
判断某个键是否存在: |
|||
|
|||
```go |
|||
yes := cfg.Section("").HasKey("key name") |
|||
``` |
|||
|
|||
创建一个新的键: |
|||
|
|||
```go |
|||
err := cfg.Section("").NewKey("name", "value") |
|||
``` |
|||
|
|||
获取分区下的所有键或键名: |
|||
|
|||
```go |
|||
keys := cfg.Section("").Keys() |
|||
names := cfg.Section("").KeyStrings() |
|||
``` |
|||
|
|||
获取分区下的所有键值对的克隆: |
|||
|
|||
```go |
|||
hash := cfg.Section("").KeysHash() |
|||
``` |
|||
|
|||
### 操作键值(Value) |
|||
|
|||
获取一个类型为字符串(string)的值: |
|||
|
|||
```go |
|||
val := cfg.Section("").Key("key name").String() |
|||
``` |
|||
|
|||
获取值的同时通过自定义函数进行处理验证: |
|||
|
|||
```go |
|||
val := cfg.Section("").Key("key name").Validate(func(in string) string { |
|||
if len(in) == 0 { |
|||
return "default" |
|||
} |
|||
return in |
|||
}) |
|||
``` |
|||
|
|||
如果您不需要任何对值的自动转变功能(例如递归读取),可以直接获取原值(这种方式性能最佳): |
|||
|
|||
```go |
|||
val := cfg.Section("").Key("key name").Value() |
|||
``` |
|||
|
|||
判断某个原值是否存在: |
|||
|
|||
```go |
|||
yes := cfg.Section("").HasValue("test value") |
|||
``` |
|||
|
|||
获取其它类型的值: |
|||
|
|||
```go |
|||
// 布尔值的规则: |
|||
// true 当值为:1, t, T, TRUE, true, True, YES, yes, Yes, y, ON, on, On |
|||
// false 当值为:0, f, F, FALSE, false, False, NO, no, No, n, OFF, off, Off |
|||
v, err = cfg.Section("").Key("BOOL").Bool() |
|||
v, err = cfg.Section("").Key("FLOAT64").Float64() |
|||
v, err = cfg.Section("").Key("INT").Int() |
|||
v, err = cfg.Section("").Key("INT64").Int64() |
|||
v, err = cfg.Section("").Key("UINT").Uint() |
|||
v, err = cfg.Section("").Key("UINT64").Uint64() |
|||
v, err = cfg.Section("").Key("TIME").TimeFormat(time.RFC3339) |
|||
v, err = cfg.Section("").Key("TIME").Time() // RFC3339 |
|||
|
|||
v = cfg.Section("").Key("BOOL").MustBool() |
|||
v = cfg.Section("").Key("FLOAT64").MustFloat64() |
|||
v = cfg.Section("").Key("INT").MustInt() |
|||
v = cfg.Section("").Key("INT64").MustInt64() |
|||
v = cfg.Section("").Key("UINT").MustUint() |
|||
v = cfg.Section("").Key("UINT64").MustUint64() |
|||
v = cfg.Section("").Key("TIME").MustTimeFormat(time.RFC3339) |
|||
v = cfg.Section("").Key("TIME").MustTime() // RFC3339 |
|||
|
|||
// 由 Must 开头的方法名允许接收一个相同类型的参数来作为默认值, |
|||
// 当键不存在或者转换失败时,则会直接返回该默认值。 |
|||
// 但是,MustString 方法必须传递一个默认值。 |
|||
|
|||
v = cfg.Seciont("").Key("String").MustString("default") |
|||
v = cfg.Section("").Key("BOOL").MustBool(true) |
|||
v = cfg.Section("").Key("FLOAT64").MustFloat64(1.25) |
|||
v = cfg.Section("").Key("INT").MustInt(10) |
|||
v = cfg.Section("").Key("INT64").MustInt64(99) |
|||
v = cfg.Section("").Key("UINT").MustUint(3) |
|||
v = cfg.Section("").Key("UINT64").MustUint64(6) |
|||
v = cfg.Section("").Key("TIME").MustTimeFormat(time.RFC3339, time.Now()) |
|||
v = cfg.Section("").Key("TIME").MustTime(time.Now()) // RFC3339 |
|||
``` |
|||
|
|||
如果我的值有好多行怎么办? |
|||
|
|||
```ini |
|||
[advance] |
|||
ADDRESS = """404 road, |
|||
NotFound, State, 5000 |
|||
Earth""" |
|||
``` |
|||
|
|||
嗯哼?小 case! |
|||
|
|||
```go |
|||
cfg.Section("advance").Key("ADDRESS").String() |
|||
|
|||
/* --- start --- |
|||
404 road, |
|||
NotFound, State, 5000 |
|||
Earth |
|||
------ end --- */ |
|||
``` |
|||
|
|||
赞爆了!那要是我属于一行的内容写不下想要写到第二行怎么办? |
|||
|
|||
```ini |
|||
[advance] |
|||
two_lines = how about \ |
|||
continuation lines? |
|||
lots_of_lines = 1 \ |
|||
2 \ |
|||
3 \ |
|||
4 |
|||
``` |
|||
|
|||
简直是小菜一碟! |
|||
|
|||
```go |
|||
cfg.Section("advance").Key("two_lines").String() // how about continuation lines? |
|||
cfg.Section("advance").Key("lots_of_lines").String() // 1 2 3 4 |
|||
``` |
|||
|
|||
可是我有时候觉得两行连在一起特别没劲,怎么才能不自动连接两行呢? |
|||
|
|||
```go |
|||
cfg, err := ini.LoadSources(ini.LoadOptions{ |
|||
IgnoreContinuation: true, |
|||
}, "filename") |
|||
``` |
|||
|
|||
哇靠给力啊! |
|||
|
|||
需要注意的是,值两侧的单引号会被自动剔除: |
|||
|
|||
```ini |
|||
foo = "some value" // foo: some value |
|||
bar = 'some value' // bar: some value |
|||
``` |
|||
|
|||
这就是全部了?哈哈,当然不是。 |
|||
|
|||
#### 操作键值的辅助方法 |
|||
|
|||
获取键值时设定候选值: |
|||
|
|||
```go |
|||
v = cfg.Section("").Key("STRING").In("default", []string{"str", "arr", "types"}) |
|||
v = cfg.Section("").Key("FLOAT64").InFloat64(1.1, []float64{1.25, 2.5, 3.75}) |
|||
v = cfg.Section("").Key("INT").InInt(5, []int{10, 20, 30}) |
|||
v = cfg.Section("").Key("INT64").InInt64(10, []int64{10, 20, 30}) |
|||
v = cfg.Section("").Key("UINT").InUint(4, []int{3, 6, 9}) |
|||
v = cfg.Section("").Key("UINT64").InUint64(8, []int64{3, 6, 9}) |
|||
v = cfg.Section("").Key("TIME").InTimeFormat(time.RFC3339, time.Now(), []time.Time{time1, time2, time3}) |
|||
v = cfg.Section("").Key("TIME").InTime(time.Now(), []time.Time{time1, time2, time3}) // RFC3339 |
|||
``` |
|||
|
|||
如果获取到的值不是候选值的任意一个,则会返回默认值,而默认值不需要是候选值中的一员。 |
|||
|
|||
验证获取的值是否在指定范围内: |
|||
|
|||
```go |
|||
vals = cfg.Section("").Key("FLOAT64").RangeFloat64(0.0, 1.1, 2.2) |
|||
vals = cfg.Section("").Key("INT").RangeInt(0, 10, 20) |
|||
vals = cfg.Section("").Key("INT64").RangeInt64(0, 10, 20) |
|||
vals = cfg.Section("").Key("UINT").RangeUint(0, 3, 9) |
|||
vals = cfg.Section("").Key("UINT64").RangeUint64(0, 3, 9) |
|||
vals = cfg.Section("").Key("TIME").RangeTimeFormat(time.RFC3339, time.Now(), minTime, maxTime) |
|||
vals = cfg.Section("").Key("TIME").RangeTime(time.Now(), minTime, maxTime) // RFC3339 |
|||
``` |
|||
|
|||
##### 自动分割键值到切片(slice) |
|||
|
|||
当存在无效输入时,使用零值代替: |
|||
|
|||
```go |
|||
// Input: 1.1, 2.2, 3.3, 4.4 -> [1.1 2.2 3.3 4.4] |
|||
// Input: how, 2.2, are, you -> [0.0 2.2 0.0 0.0] |
|||
vals = cfg.Section("").Key("STRINGS").Strings(",") |
|||
vals = cfg.Section("").Key("FLOAT64S").Float64s(",") |
|||
vals = cfg.Section("").Key("INTS").Ints(",") |
|||
vals = cfg.Section("").Key("INT64S").Int64s(",") |
|||
vals = cfg.Section("").Key("UINTS").Uints(",") |
|||
vals = cfg.Section("").Key("UINT64S").Uint64s(",") |
|||
vals = cfg.Section("").Key("TIMES").Times(",") |
|||
``` |
|||
|
|||
从结果切片中剔除无效输入: |
|||
|
|||
```go |
|||
// Input: 1.1, 2.2, 3.3, 4.4 -> [1.1 2.2 3.3 4.4] |
|||
// Input: how, 2.2, are, you -> [2.2] |
|||
vals = cfg.Section("").Key("FLOAT64S").ValidFloat64s(",") |
|||
vals = cfg.Section("").Key("INTS").ValidInts(",") |
|||
vals = cfg.Section("").Key("INT64S").ValidInt64s(",") |
|||
vals = cfg.Section("").Key("UINTS").ValidUints(",") |
|||
vals = cfg.Section("").Key("UINT64S").ValidUint64s(",") |
|||
vals = cfg.Section("").Key("TIMES").ValidTimes(",") |
|||
``` |
|||
|
|||
当存在无效输入时,直接返回错误: |
|||
|
|||
```go |
|||
// Input: 1.1, 2.2, 3.3, 4.4 -> [1.1 2.2 3.3 4.4] |
|||
// Input: how, 2.2, are, you -> error |
|||
vals = cfg.Section("").Key("FLOAT64S").StrictFloat64s(",") |
|||
vals = cfg.Section("").Key("INTS").StrictInts(",") |
|||
vals = cfg.Section("").Key("INT64S").StrictInt64s(",") |
|||
vals = cfg.Section("").Key("UINTS").StrictUints(",") |
|||
vals = cfg.Section("").Key("UINT64S").StrictUint64s(",") |
|||
vals = cfg.Section("").Key("TIMES").StrictTimes(",") |
|||
``` |
|||
|
|||
### 保存配置 |
|||
|
|||
终于到了这个时刻,是时候保存一下配置了。 |
|||
|
|||
比较原始的做法是输出配置到某个文件: |
|||
|
|||
```go |
|||
// ... |
|||
err = cfg.SaveTo("my.ini") |
|||
err = cfg.SaveToIndent("my.ini", "\t") |
|||
``` |
|||
|
|||
另一个比较高级的做法是写入到任何实现 `io.Writer` 接口的对象中: |
|||
|
|||
```go |
|||
// ... |
|||
cfg.WriteTo(writer) |
|||
cfg.WriteToIndent(writer, "\t") |
|||
``` |
|||
|
|||
### 高级用法 |
|||
|
|||
#### 递归读取键值 |
|||
|
|||
在获取所有键值的过程中,特殊语法 `%(<name>)s` 会被应用,其中 `<name>` 可以是相同分区或者默认分区下的键名。字符串 `%(<name>)s` 会被相应的键值所替代,如果指定的键不存在,则会用空字符串替代。您可以最多使用 99 层的递归嵌套。 |
|||
|
|||
```ini |
|||
NAME = ini |
|||
|
|||
[author] |
|||
NAME = Unknwon |
|||
GITHUB = https://github.com/%(NAME)s |
|||
|
|||
[package] |
|||
FULL_NAME = github.com/go-ini/%(NAME)s |
|||
``` |
|||
|
|||
```go |
|||
cfg.Section("author").Key("GITHUB").String() // https://github.com/Unknwon |
|||
cfg.Section("package").Key("FULL_NAME").String() // github.com/go-ini/ini |
|||
``` |
|||
|
|||
#### 读取父子分区 |
|||
|
|||
您可以在分区名称中使用 `.` 来表示两个或多个分区之间的父子关系。如果某个键在子分区中不存在,则会去它的父分区中再次寻找,直到没有父分区为止。 |
|||
|
|||
```ini |
|||
NAME = ini |
|||
VERSION = v1 |
|||
IMPORT_PATH = gopkg.in/%(NAME)s.%(VERSION)s |
|||
|
|||
[package] |
|||
CLONE_URL = https://%(IMPORT_PATH)s |
|||
|
|||
[package.sub] |
|||
``` |
|||
|
|||
```go |
|||
cfg.Section("package.sub").Key("CLONE_URL").String() // https://gopkg.in/ini.v1 |
|||
``` |
|||
|
|||
#### 获取上级父分区下的所有键名 |
|||
|
|||
```go |
|||
cfg.Section("package.sub").ParentKeys() // ["CLONE_URL"] |
|||
``` |
|||
|
|||
#### 读取自增键名 |
|||
|
|||
如果数据源中的键名为 `-`,则认为该键使用了自增键名的特殊语法。计数器从 1 开始,并且分区之间是相互独立的。 |
|||
|
|||
```ini |
|||
[features] |
|||
-: Support read/write comments of keys and sections |
|||
-: Support auto-increment of key names |
|||
-: Support load multiple files to overwrite key values |
|||
``` |
|||
|
|||
```go |
|||
cfg.Section("features").KeyStrings() // []{"#1", "#2", "#3"} |
|||
``` |
|||
|
|||
### 映射到结构 |
|||
|
|||
想要使用更加面向对象的方式玩转 INI 吗?好主意。 |
|||
|
|||
```ini |
|||
Name = Unknwon |
|||
age = 21 |
|||
Male = true |
|||
Born = 1993-01-01T20:17:05Z |
|||
|
|||
[Note] |
|||
Content = Hi is a good man! |
|||
Cities = HangZhou, Boston |
|||
``` |
|||
|
|||
```go |
|||
type Note struct { |
|||
Content string |
|||
Cities []string |
|||
} |
|||
|
|||
type Person struct { |
|||
Name string |
|||
Age int `ini:"age"` |
|||
Male bool |
|||
Born time.Time |
|||
Note |
|||
Created time.Time `ini:"-"` |
|||
} |
|||
|
|||
func main() { |
|||
cfg, err := ini.Load("path/to/ini") |
|||
// ... |
|||
p := new(Person) |
|||
err = cfg.MapTo(p) |
|||
// ... |
|||
|
|||
// 一切竟可以如此的简单。 |
|||
err = ini.MapTo(p, "path/to/ini") |
|||
// ... |
|||
|
|||
// 嗯哼?只需要映射一个分区吗? |
|||
n := new(Note) |
|||
err = cfg.Section("Note").MapTo(n) |
|||
// ... |
|||
} |
|||
``` |
|||
|
|||
结构的字段怎么设置默认值呢?很简单,只要在映射之前对指定字段进行赋值就可以了。如果键未找到或者类型错误,该值不会发生改变。 |
|||
|
|||
```go |
|||
// ... |
|||
p := &Person{ |
|||
Name: "Joe", |
|||
} |
|||
// ... |
|||
``` |
|||
|
|||
这样玩 INI 真的好酷啊!然而,如果不能还给我原来的配置文件,有什么卵用? |
|||
|
|||
### 从结构反射 |
|||
|
|||
可是,我有说不能吗? |
|||
|
|||
```go |
|||
type Embeded struct { |
|||
Dates []time.Time `delim:"|"` |
|||
Places []string `ini:"places,omitempty"` |
|||
None []int `ini:",omitempty"` |
|||
} |
|||
|
|||
type Author struct { |
|||
Name string `ini:"NAME"` |
|||
Male bool |
|||
Age int |
|||
GPA float64 |
|||
NeverMind string `ini:"-"` |
|||
*Embeded |
|||
} |
|||
|
|||
func main() { |
|||
a := &Author{"Unknwon", true, 21, 2.8, "", |
|||
&Embeded{ |
|||
[]time.Time{time.Now(), time.Now()}, |
|||
[]string{"HangZhou", "Boston"}, |
|||
[]int{}, |
|||
}} |
|||
cfg := ini.Empty() |
|||
err = ini.ReflectFrom(cfg, a) |
|||
// ... |
|||
} |
|||
``` |
|||
|
|||
瞧瞧,奇迹发生了。 |
|||
|
|||
```ini |
|||
NAME = Unknwon |
|||
Male = true |
|||
Age = 21 |
|||
GPA = 2.8 |
|||
|
|||
[Embeded] |
|||
Dates = 2015-08-07T22:14:22+08:00|2015-08-07T22:14:22+08:00 |
|||
places = HangZhou,Boston |
|||
``` |
|||
|
|||
#### 名称映射器(Name Mapper) |
|||
|
|||
为了节省您的时间并简化代码,本库支持类型为 [`NameMapper`](https://gowalker.org/gopkg.in/ini.v1#NameMapper) 的名称映射器,该映射器负责结构字段名与分区名和键名之间的映射。 |
|||
|
|||
目前有 2 款内置的映射器: |
|||
|
|||
- `AllCapsUnderscore`:该映射器将字段名转换至格式 `ALL_CAPS_UNDERSCORE` 后再去匹配分区名和键名。 |
|||
- `TitleUnderscore`:该映射器将字段名转换至格式 `title_underscore` 后再去匹配分区名和键名。 |
|||
|
|||
使用方法: |
|||
|
|||
```go |
|||
type Info struct{ |
|||
PackageName string |
|||
} |
|||
|
|||
func main() { |
|||
err = ini.MapToWithMapper(&Info{}, ini.TitleUnderscore, []byte("package_name=ini")) |
|||
// ... |
|||
|
|||
cfg, err := ini.Load([]byte("PACKAGE_NAME=ini")) |
|||
// ... |
|||
info := new(Info) |
|||
cfg.NameMapper = ini.AllCapsUnderscore |
|||
err = cfg.MapTo(info) |
|||
// ... |
|||
} |
|||
``` |
|||
|
|||
使用函数 `ini.ReflectFromWithMapper` 时也可应用相同的规则。 |
|||
|
|||
#### 值映射器(Value Mapper) |
|||
|
|||
值映射器允许使用一个自定义函数自动展开值的具体内容,例如:运行时获取环境变量: |
|||
|
|||
```go |
|||
type Env struct { |
|||
Foo string `ini:"foo"` |
|||
} |
|||
|
|||
func main() { |
|||
cfg, err := ini.Load([]byte("[env]\nfoo = ${MY_VAR}\n") |
|||
cfg.ValueMapper = os.ExpandEnv |
|||
// ... |
|||
env := &Env{} |
|||
err = cfg.Section("env").MapTo(env) |
|||
} |
|||
``` |
|||
|
|||
本例中,`env.Foo` 将会是运行时所获取到环境变量 `MY_VAR` 的值。 |
|||
|
|||
#### 映射/反射的其它说明 |
|||
|
|||
任何嵌入的结构都会被默认认作一个不同的分区,并且不会自动产生所谓的父子分区关联: |
|||
|
|||
```go |
|||
type Child struct { |
|||
Age string |
|||
} |
|||
|
|||
type Parent struct { |
|||
Name string |
|||
Child |
|||
} |
|||
|
|||
type Config struct { |
|||
City string |
|||
Parent |
|||
} |
|||
``` |
|||
|
|||
示例配置文件: |
|||
|
|||
```ini |
|||
City = Boston |
|||
|
|||
[Parent] |
|||
Name = Unknwon |
|||
|
|||
[Child] |
|||
Age = 21 |
|||
``` |
|||
|
|||
很好,但是,我就是要嵌入结构也在同一个分区。好吧,你爹是李刚! |
|||
|
|||
```go |
|||
type Child struct { |
|||
Age string |
|||
} |
|||
|
|||
type Parent struct { |
|||
Name string |
|||
Child `ini:"Parent"` |
|||
} |
|||
|
|||
type Config struct { |
|||
City string |
|||
Parent |
|||
} |
|||
``` |
|||
|
|||
示例配置文件: |
|||
|
|||
```ini |
|||
City = Boston |
|||
|
|||
[Parent] |
|||
Name = Unknwon |
|||
Age = 21 |
|||
``` |
|||
|
|||
## 获取帮助 |
|||
|
|||
- [API 文档](https://gowalker.org/gopkg.in/ini.v1) |
|||
- [创建工单](https://github.com/go-ini/ini/issues/new) |
|||
|
|||
## 常见问题 |
|||
|
|||
### 字段 `BlockMode` 是什么? |
|||
|
|||
默认情况下,本库会在您进行读写操作时采用锁机制来确保数据时间。但在某些情况下,您非常确定只进行读操作。此时,您可以通过设置 `cfg.BlockMode = false` 来将读操作提升大约 **50-70%** 的性能。 |
|||
|
|||
### 为什么要写另一个 INI 解析库? |
|||
|
|||
许多人都在使用我的 [goconfig](https://github.com/Unknwon/goconfig) 来完成对 INI 文件的操作,但我希望使用更加 Go 风格的代码。并且当您设置 `cfg.BlockMode = false` 时,会有大约 **10-30%** 的性能提升。 |
|||
|
|||
为了做出这些改变,我必须对 API 进行破坏,所以新开一个仓库是最安全的做法。除此之外,本库直接使用 `gopkg.in` 来进行版本化发布。(其实真相是导入路径更短了) |
@ -0,0 +1,32 @@ |
|||
// Copyright 2016 Unknwon
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
|||
// not use this file except in compliance with the License. You may obtain
|
|||
// a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|||
// License for the specific language governing permissions and limitations
|
|||
// under the License.
|
|||
|
|||
package ini |
|||
|
|||
import ( |
|||
"fmt" |
|||
) |
|||
|
|||
type ErrDelimiterNotFound struct { |
|||
Line string |
|||
} |
|||
|
|||
func IsErrDelimiterNotFound(err error) bool { |
|||
_, ok := err.(ErrDelimiterNotFound) |
|||
return ok |
|||
} |
|||
|
|||
func (err ErrDelimiterNotFound) Error() string { |
|||
return fmt.Sprintf("key-value delimiter not found: %s", err.Line) |
|||
} |
@ -0,0 +1,501 @@ |
|||
// Copyright 2014 Unknwon
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
|||
// not use this file except in compliance with the License. You may obtain
|
|||
// a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|||
// License for the specific language governing permissions and limitations
|
|||
// under the License.
|
|||
|
|||
// Package ini provides INI file read and write functionality in Go.
|
|||
package ini |
|||
|
|||
import ( |
|||
"bytes" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"os" |
|||
"regexp" |
|||
"runtime" |
|||
"strconv" |
|||
"strings" |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
const ( |
|||
// Name for default section. You can use this constant or the string literal.
|
|||
// In most of cases, an empty string is all you need to access the section.
|
|||
DEFAULT_SECTION = "DEFAULT" |
|||
|
|||
// Maximum allowed depth when recursively substituing variable names.
|
|||
_DEPTH_VALUES = 99 |
|||
_VERSION = "1.21.1" |
|||
) |
|||
|
|||
// Version returns current package version literal.
|
|||
func Version() string { |
|||
return _VERSION |
|||
} |
|||
|
|||
var ( |
|||
// Delimiter to determine or compose a new line.
|
|||
// This variable will be changed to "\r\n" automatically on Windows
|
|||
// at package init time.
|
|||
LineBreak = "\n" |
|||
|
|||
// Variable regexp pattern: %(variable)s
|
|||
varPattern = regexp.MustCompile(`%\(([^\)]+)\)s`) |
|||
|
|||
// Indicate whether to align "=" sign with spaces to produce pretty output
|
|||
// or reduce all possible spaces for compact format.
|
|||
PrettyFormat = true |
|||
|
|||
// Explicitly write DEFAULT section header
|
|||
DefaultHeader = false |
|||
) |
|||
|
|||
func init() { |
|||
if runtime.GOOS == "windows" { |
|||
LineBreak = "\r\n" |
|||
} |
|||
} |
|||
|
|||
func inSlice(str string, s []string) bool { |
|||
for _, v := range s { |
|||
if str == v { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// dataSource is an interface that returns object which can be read and closed.
|
|||
type dataSource interface { |
|||
ReadCloser() (io.ReadCloser, error) |
|||
} |
|||
|
|||
// sourceFile represents an object that contains content on the local file system.
|
|||
type sourceFile struct { |
|||
name string |
|||
} |
|||
|
|||
func (s sourceFile) ReadCloser() (_ io.ReadCloser, err error) { |
|||
return os.Open(s.name) |
|||
} |
|||
|
|||
type bytesReadCloser struct { |
|||
reader io.Reader |
|||
} |
|||
|
|||
func (rc *bytesReadCloser) Read(p []byte) (n int, err error) { |
|||
return rc.reader.Read(p) |
|||
} |
|||
|
|||
func (rc *bytesReadCloser) Close() error { |
|||
return nil |
|||
} |
|||
|
|||
// sourceData represents an object that contains content in memory.
|
|||
type sourceData struct { |
|||
data []byte |
|||
} |
|||
|
|||
func (s *sourceData) ReadCloser() (io.ReadCloser, error) { |
|||
return &bytesReadCloser{bytes.NewReader(s.data)}, nil |
|||
} |
|||
|
|||
// File represents a combination of a or more INI file(s) in memory.
|
|||
type File struct { |
|||
// Should make things safe, but sometimes doesn't matter.
|
|||
BlockMode bool |
|||
// Make sure data is safe in multiple goroutines.
|
|||
lock sync.RWMutex |
|||
|
|||
// Allow combination of multiple data sources.
|
|||
dataSources []dataSource |
|||
// Actual data is stored here.
|
|||
sections map[string]*Section |
|||
|
|||
// To keep data in order.
|
|||
sectionList []string |
|||
|
|||
options LoadOptions |
|||
|
|||
NameMapper |
|||
ValueMapper |
|||
} |
|||
|
|||
// newFile initializes File object with given data sources.
|
|||
func newFile(dataSources []dataSource, opts LoadOptions) *File { |
|||
return &File{ |
|||
BlockMode: true, |
|||
dataSources: dataSources, |
|||
sections: make(map[string]*Section), |
|||
sectionList: make([]string, 0, 10), |
|||
options: opts, |
|||
} |
|||
} |
|||
|
|||
func parseDataSource(source interface{}) (dataSource, error) { |
|||
switch s := source.(type) { |
|||
case string: |
|||
return sourceFile{s}, nil |
|||
case []byte: |
|||
return &sourceData{s}, nil |
|||
default: |
|||
return nil, fmt.Errorf("error parsing data source: unknown type '%s'", s) |
|||
} |
|||
} |
|||
|
|||
type LoadOptions struct { |
|||
// Loose indicates whether the parser should ignore nonexistent files or return error.
|
|||
Loose bool |
|||
// Insensitive indicates whether the parser forces all section and key names to lowercase.
|
|||
Insensitive bool |
|||
// IgnoreContinuation indicates whether to ignore continuation lines while parsing.
|
|||
IgnoreContinuation bool |
|||
// AllowBooleanKeys indicates whether to allow boolean type keys or treat as value is missing.
|
|||
// This type of keys are mostly used in my.cnf.
|
|||
AllowBooleanKeys bool |
|||
} |
|||
|
|||
func LoadSources(opts LoadOptions, source interface{}, others ...interface{}) (_ *File, err error) { |
|||
sources := make([]dataSource, len(others)+1) |
|||
sources[0], err = parseDataSource(source) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
for i := range others { |
|||
sources[i+1], err = parseDataSource(others[i]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
f := newFile(sources, opts) |
|||
if err = f.Reload(); err != nil { |
|||
return nil, err |
|||
} |
|||
return f, nil |
|||
} |
|||
|
|||
// Load loads and parses from INI data sources.
|
|||
// Arguments can be mixed of file name with string type, or raw data in []byte.
|
|||
// It will return error if list contains nonexistent files.
|
|||
func Load(source interface{}, others ...interface{}) (*File, error) { |
|||
return LoadSources(LoadOptions{}, source, others...) |
|||
} |
|||
|
|||
// LooseLoad has exactly same functionality as Load function
|
|||
// except it ignores nonexistent files instead of returning error.
|
|||
func LooseLoad(source interface{}, others ...interface{}) (*File, error) { |
|||
return LoadSources(LoadOptions{Loose: true}, source, others...) |
|||
} |
|||
|
|||
// InsensitiveLoad has exactly same functionality as Load function
|
|||
// except it forces all section and key names to be lowercased.
|
|||
func InsensitiveLoad(source interface{}, others ...interface{}) (*File, error) { |
|||
return LoadSources(LoadOptions{Insensitive: true}, source, others...) |
|||
} |
|||
|
|||
// Empty returns an empty file object.
|
|||
func Empty() *File { |
|||
// Ignore error here, we sure our data is good.
|
|||
f, _ := Load([]byte("")) |
|||
return f |
|||
} |
|||
|
|||
// NewSection creates a new section.
|
|||
func (f *File) NewSection(name string) (*Section, error) { |
|||
if len(name) == 0 { |
|||
return nil, errors.New("error creating new section: empty section name") |
|||
} else if f.options.Insensitive && name != DEFAULT_SECTION { |
|||
name = strings.ToLower(name) |
|||
} |
|||
|
|||
if f.BlockMode { |
|||
f.lock.Lock() |
|||
defer f.lock.Unlock() |
|||
} |
|||
|
|||
if inSlice(name, f.sectionList) { |
|||
return f.sections[name], nil |
|||
} |
|||
|
|||
f.sectionList = append(f.sectionList, name) |
|||
f.sections[name] = newSection(f, name) |
|||
return f.sections[name], nil |
|||
} |
|||
|
|||
// NewSections creates a list of sections.
|
|||
func (f *File) NewSections(names ...string) (err error) { |
|||
for _, name := range names { |
|||
if _, err = f.NewSection(name); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// GetSection returns section by given name.
|
|||
func (f *File) GetSection(name string) (*Section, error) { |
|||
if len(name) == 0 { |
|||
name = DEFAULT_SECTION |
|||
} else if f.options.Insensitive { |
|||
name = strings.ToLower(name) |
|||
} |
|||
|
|||
if f.BlockMode { |
|||
f.lock.RLock() |
|||
defer f.lock.RUnlock() |
|||
} |
|||
|
|||
sec := f.sections[name] |
|||
if sec == nil { |
|||
return nil, fmt.Errorf("section '%s' does not exist", name) |
|||
} |
|||
return sec, nil |
|||
} |
|||
|
|||
// Section assumes named section exists and returns a zero-value when not.
|
|||
func (f *File) Section(name string) *Section { |
|||
sec, err := f.GetSection(name) |
|||
if err != nil { |
|||
// Note: It's OK here because the only possible error is empty section name,
|
|||
// but if it's empty, this piece of code won't be executed.
|
|||
sec, _ = f.NewSection(name) |
|||
return sec |
|||
} |
|||
return sec |
|||
} |
|||
|
|||
// Section returns list of Section.
|
|||
func (f *File) Sections() []*Section { |
|||
sections := make([]*Section, len(f.sectionList)) |
|||
for i := range f.sectionList { |
|||
sections[i] = f.Section(f.sectionList[i]) |
|||
} |
|||
return sections |
|||
} |
|||
|
|||
// SectionStrings returns list of section names.
|
|||
func (f *File) SectionStrings() []string { |
|||
list := make([]string, len(f.sectionList)) |
|||
copy(list, f.sectionList) |
|||
return list |
|||
} |
|||
|
|||
// DeleteSection deletes a section.
|
|||
func (f *File) DeleteSection(name string) { |
|||
if f.BlockMode { |
|||
f.lock.Lock() |
|||
defer f.lock.Unlock() |
|||
} |
|||
|
|||
if len(name) == 0 { |
|||
name = DEFAULT_SECTION |
|||
} |
|||
|
|||
for i, s := range f.sectionList { |
|||
if s == name { |
|||
f.sectionList = append(f.sectionList[:i], f.sectionList[i+1:]...) |
|||
delete(f.sections, name) |
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (f *File) reload(s dataSource) error { |
|||
r, err := s.ReadCloser() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
defer r.Close() |
|||
|
|||
return f.parse(r) |
|||
} |
|||
|
|||
// Reload reloads and parses all data sources.
|
|||
func (f *File) Reload() (err error) { |
|||
for _, s := range f.dataSources { |
|||
if err = f.reload(s); err != nil { |
|||
// In loose mode, we create an empty default section for nonexistent files.
|
|||
if os.IsNotExist(err) && f.options.Loose { |
|||
f.parse(bytes.NewBuffer(nil)) |
|||
continue |
|||
} |
|||
return err |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// Append appends one or more data sources and reloads automatically.
|
|||
func (f *File) Append(source interface{}, others ...interface{}) error { |
|||
ds, err := parseDataSource(source) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
f.dataSources = append(f.dataSources, ds) |
|||
for _, s := range others { |
|||
ds, err = parseDataSource(s) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
f.dataSources = append(f.dataSources, ds) |
|||
} |
|||
return f.Reload() |
|||
} |
|||
|
|||
// WriteToIndent writes content into io.Writer with given indention.
|
|||
// If PrettyFormat has been set to be true,
|
|||
// it will align "=" sign with spaces under each section.
|
|||
func (f *File) WriteToIndent(w io.Writer, indent string) (n int64, err error) { |
|||
equalSign := "=" |
|||
if PrettyFormat { |
|||
equalSign = " = " |
|||
} |
|||
|
|||
// Use buffer to make sure target is safe until finish encoding.
|
|||
buf := bytes.NewBuffer(nil) |
|||
for i, sname := range f.sectionList { |
|||
sec := f.Section(sname) |
|||
if len(sec.Comment) > 0 { |
|||
if sec.Comment[0] != '#' && sec.Comment[0] != ';' { |
|||
sec.Comment = "; " + sec.Comment |
|||
} |
|||
if _, err = buf.WriteString(sec.Comment + LineBreak); err != nil { |
|||
return 0, err |
|||
} |
|||
} |
|||
|
|||
if i > 0 || DefaultHeader { |
|||
if _, err = buf.WriteString("[" + sname + "]" + LineBreak); err != nil { |
|||
return 0, err |
|||
} |
|||
} else { |
|||
// Write nothing if default section is empty
|
|||
if len(sec.keyList) == 0 { |
|||
continue |
|||
} |
|||
} |
|||
|
|||
// Count and generate alignment length and buffer spaces using the
|
|||
// longest key. Keys may be modifed if they contain certain characters so
|
|||
// we need to take that into account in our calculation.
|
|||
alignLength := 0 |
|||
if PrettyFormat { |
|||
for _, kname := range sec.keyList { |
|||
keyLength := len(kname) |
|||
// First case will surround key by ` and second by """
|
|||
if strings.ContainsAny(kname, "\"=:") { |
|||
keyLength += 2 |
|||
} else if strings.Contains(kname, "`") { |
|||
keyLength += 6 |
|||
} |
|||
|
|||
if keyLength > alignLength { |
|||
alignLength = keyLength |
|||
} |
|||
} |
|||
} |
|||
alignSpaces := bytes.Repeat([]byte(" "), alignLength) |
|||
|
|||
for _, kname := range sec.keyList { |
|||
key := sec.Key(kname) |
|||
if len(key.Comment) > 0 { |
|||
if len(indent) > 0 && sname != DEFAULT_SECTION { |
|||
buf.WriteString(indent) |
|||
} |
|||
if key.Comment[0] != '#' && key.Comment[0] != ';' { |
|||
key.Comment = "; " + key.Comment |
|||
} |
|||
if _, err = buf.WriteString(key.Comment + LineBreak); err != nil { |
|||
return 0, err |
|||
} |
|||
} |
|||
|
|||
if len(indent) > 0 && sname != DEFAULT_SECTION { |
|||
buf.WriteString(indent) |
|||
} |
|||
|
|||
switch { |
|||
case key.isAutoIncrement: |
|||
kname = "-" |
|||
case strings.ContainsAny(kname, "\"=:"): |
|||
kname = "`" + kname + "`" |
|||
case strings.Contains(kname, "`"): |
|||
kname = `"""` + kname + `"""` |
|||
} |
|||
if _, err = buf.WriteString(kname); err != nil { |
|||
return 0, err |
|||
} |
|||
|
|||
if key.isBooleanType { |
|||
continue |
|||
} |
|||
|
|||
// Write out alignment spaces before "=" sign
|
|||
if PrettyFormat { |
|||
buf.Write(alignSpaces[:alignLength-len(kname)]) |
|||
} |
|||
|
|||
val := key.value |
|||
// In case key value contains "\n", "`", "\"", "#" or ";"
|
|||
if strings.ContainsAny(val, "\n`") { |
|||
val = `"""` + val + `"""` |
|||
} else if strings.ContainsAny(val, "#;") { |
|||
val = "`" + val + "`" |
|||
} |
|||
if _, err = buf.WriteString(equalSign + val + LineBreak); err != nil { |
|||
return 0, err |
|||
} |
|||
} |
|||
|
|||
// Put a line between sections
|
|||
if _, err = buf.WriteString(LineBreak); err != nil { |
|||
return 0, err |
|||
} |
|||
} |
|||
|
|||
return buf.WriteTo(w) |
|||
} |
|||
|
|||
// WriteTo writes file content into io.Writer.
|
|||
func (f *File) WriteTo(w io.Writer) (int64, error) { |
|||
return f.WriteToIndent(w, "") |
|||
} |
|||
|
|||
// SaveToIndent writes content to file system with given value indention.
|
|||
func (f *File) SaveToIndent(filename, indent string) error { |
|||
// Note: Because we are truncating with os.Create,
|
|||
// so it's safer to save to a temporary file location and rename afte done.
|
|||
tmpPath := filename + "." + strconv.Itoa(time.Now().Nanosecond()) + ".tmp" |
|||
defer os.Remove(tmpPath) |
|||
|
|||
fw, err := os.Create(tmpPath) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
if _, err = f.WriteToIndent(fw, indent); err != nil { |
|||
fw.Close() |
|||
return err |
|||
} |
|||
fw.Close() |
|||
|
|||
// Remove old file and rename the new one.
|
|||
os.Remove(filename) |
|||
return os.Rename(tmpPath, filename) |
|||
} |
|||
|
|||
// SaveTo writes content to file system.
|
|||
func (f *File) SaveTo(filename string) error { |
|||
return f.SaveToIndent(filename, "") |
|||
} |
@ -0,0 +1,633 @@ |
|||
// Copyright 2014 Unknwon
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
|||
// not use this file except in compliance with the License. You may obtain
|
|||
// a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|||
// License for the specific language governing permissions and limitations
|
|||
// under the License.
|
|||
|
|||
package ini |
|||
|
|||
import ( |
|||
"fmt" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
) |
|||
|
|||
// Key represents a key under a section.
|
|||
type Key struct { |
|||
s *Section |
|||
name string |
|||
value string |
|||
isAutoIncrement bool |
|||
isBooleanType bool |
|||
|
|||
Comment string |
|||
} |
|||
|
|||
// ValueMapper represents a mapping function for values, e.g. os.ExpandEnv
|
|||
type ValueMapper func(string) string |
|||
|
|||
// Name returns name of key.
|
|||
func (k *Key) Name() string { |
|||
return k.name |
|||
} |
|||
|
|||
// Value returns raw value of key for performance purpose.
|
|||
func (k *Key) Value() string { |
|||
return k.value |
|||
} |
|||
|
|||
// String returns string representation of value.
|
|||
func (k *Key) String() string { |
|||
val := k.value |
|||
if k.s.f.ValueMapper != nil { |
|||
val = k.s.f.ValueMapper(val) |
|||
} |
|||
if strings.Index(val, "%") == -1 { |
|||
return val |
|||
} |
|||
|
|||
for i := 0; i < _DEPTH_VALUES; i++ { |
|||
vr := varPattern.FindString(val) |
|||
if len(vr) == 0 { |
|||
break |
|||
} |
|||
|
|||
// Take off leading '%(' and trailing ')s'.
|
|||
noption := strings.TrimLeft(vr, "%(") |
|||
noption = strings.TrimRight(noption, ")s") |
|||
|
|||
// Search in the same section.
|
|||
nk, err := k.s.GetKey(noption) |
|||
if err != nil { |
|||
// Search again in default section.
|
|||
nk, _ = k.s.f.Section("").GetKey(noption) |
|||
} |
|||
|
|||
// Substitute by new value and take off leading '%(' and trailing ')s'.
|
|||
val = strings.Replace(val, vr, nk.value, -1) |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// Validate accepts a validate function which can
|
|||
// return modifed result as key value.
|
|||
func (k *Key) Validate(fn func(string) string) string { |
|||
return fn(k.String()) |
|||
} |
|||
|
|||
// parseBool returns the boolean value represented by the string.
|
|||
//
|
|||
// It accepts 1, t, T, TRUE, true, True, YES, yes, Yes, y, ON, on, On,
|
|||
// 0, f, F, FALSE, false, False, NO, no, No, n, OFF, off, Off.
|
|||
// Any other value returns an error.
|
|||
func parseBool(str string) (value bool, err error) { |
|||
switch str { |
|||
case "1", "t", "T", "true", "TRUE", "True", "YES", "yes", "Yes", "y", "ON", "on", "On": |
|||
return true, nil |
|||
case "0", "f", "F", "false", "FALSE", "False", "NO", "no", "No", "n", "OFF", "off", "Off": |
|||
return false, nil |
|||
} |
|||
return false, fmt.Errorf("parsing \"%s\": invalid syntax", str) |
|||
} |
|||
|
|||
// Bool returns bool type value.
|
|||
func (k *Key) Bool() (bool, error) { |
|||
return parseBool(k.String()) |
|||
} |
|||
|
|||
// Float64 returns float64 type value.
|
|||
func (k *Key) Float64() (float64, error) { |
|||
return strconv.ParseFloat(k.String(), 64) |
|||
} |
|||
|
|||
// Int returns int type value.
|
|||
func (k *Key) Int() (int, error) { |
|||
return strconv.Atoi(k.String()) |
|||
} |
|||
|
|||
// Int64 returns int64 type value.
|
|||
func (k *Key) Int64() (int64, error) { |
|||
return strconv.ParseInt(k.String(), 10, 64) |
|||
} |
|||
|
|||
// Uint returns uint type valued.
|
|||
func (k *Key) Uint() (uint, error) { |
|||
u, e := strconv.ParseUint(k.String(), 10, 64) |
|||
return uint(u), e |
|||
} |
|||
|
|||
// Uint64 returns uint64 type value.
|
|||
func (k *Key) Uint64() (uint64, error) { |
|||
return strconv.ParseUint(k.String(), 10, 64) |
|||
} |
|||
|
|||
// Duration returns time.Duration type value.
|
|||
func (k *Key) Duration() (time.Duration, error) { |
|||
return time.ParseDuration(k.String()) |
|||
} |
|||
|
|||
// TimeFormat parses with given format and returns time.Time type value.
|
|||
func (k *Key) TimeFormat(format string) (time.Time, error) { |
|||
return time.Parse(format, k.String()) |
|||
} |
|||
|
|||
// Time parses with RFC3339 format and returns time.Time type value.
|
|||
func (k *Key) Time() (time.Time, error) { |
|||
return k.TimeFormat(time.RFC3339) |
|||
} |
|||
|
|||
// MustString returns default value if key value is empty.
|
|||
func (k *Key) MustString(defaultVal string) string { |
|||
val := k.String() |
|||
if len(val) == 0 { |
|||
k.value = defaultVal |
|||
return defaultVal |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// MustBool always returns value without error,
|
|||
// it returns false if error occurs.
|
|||
func (k *Key) MustBool(defaultVal ...bool) bool { |
|||
val, err := k.Bool() |
|||
if len(defaultVal) > 0 && err != nil { |
|||
k.value = strconv.FormatBool(defaultVal[0]) |
|||
return defaultVal[0] |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// MustFloat64 always returns value without error,
|
|||
// it returns 0.0 if error occurs.
|
|||
func (k *Key) MustFloat64(defaultVal ...float64) float64 { |
|||
val, err := k.Float64() |
|||
if len(defaultVal) > 0 && err != nil { |
|||
k.value = strconv.FormatFloat(defaultVal[0], 'f', -1, 64) |
|||
return defaultVal[0] |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// MustInt always returns value without error,
|
|||
// it returns 0 if error occurs.
|
|||
func (k *Key) MustInt(defaultVal ...int) int { |
|||
val, err := k.Int() |
|||
if len(defaultVal) > 0 && err != nil { |
|||
k.value = strconv.FormatInt(int64(defaultVal[0]), 10) |
|||
return defaultVal[0] |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// MustInt64 always returns value without error,
|
|||
// it returns 0 if error occurs.
|
|||
func (k *Key) MustInt64(defaultVal ...int64) int64 { |
|||
val, err := k.Int64() |
|||
if len(defaultVal) > 0 && err != nil { |
|||
k.value = strconv.FormatInt(defaultVal[0], 10) |
|||
return defaultVal[0] |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// MustUint always returns value without error,
|
|||
// it returns 0 if error occurs.
|
|||
func (k *Key) MustUint(defaultVal ...uint) uint { |
|||
val, err := k.Uint() |
|||
if len(defaultVal) > 0 && err != nil { |
|||
k.value = strconv.FormatUint(uint64(defaultVal[0]), 10) |
|||
return defaultVal[0] |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// MustUint64 always returns value without error,
|
|||
// it returns 0 if error occurs.
|
|||
func (k *Key) MustUint64(defaultVal ...uint64) uint64 { |
|||
val, err := k.Uint64() |
|||
if len(defaultVal) > 0 && err != nil { |
|||
k.value = strconv.FormatUint(defaultVal[0], 10) |
|||
return defaultVal[0] |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// MustDuration always returns value without error,
|
|||
// it returns zero value if error occurs.
|
|||
func (k *Key) MustDuration(defaultVal ...time.Duration) time.Duration { |
|||
val, err := k.Duration() |
|||
if len(defaultVal) > 0 && err != nil { |
|||
k.value = defaultVal[0].String() |
|||
return defaultVal[0] |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// MustTimeFormat always parses with given format and returns value without error,
|
|||
// it returns zero value if error occurs.
|
|||
func (k *Key) MustTimeFormat(format string, defaultVal ...time.Time) time.Time { |
|||
val, err := k.TimeFormat(format) |
|||
if len(defaultVal) > 0 && err != nil { |
|||
k.value = defaultVal[0].Format(format) |
|||
return defaultVal[0] |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// MustTime always parses with RFC3339 format and returns value without error,
|
|||
// it returns zero value if error occurs.
|
|||
func (k *Key) MustTime(defaultVal ...time.Time) time.Time { |
|||
return k.MustTimeFormat(time.RFC3339, defaultVal...) |
|||
} |
|||
|
|||
// In always returns value without error,
|
|||
// it returns default value if error occurs or doesn't fit into candidates.
|
|||
func (k *Key) In(defaultVal string, candidates []string) string { |
|||
val := k.String() |
|||
for _, cand := range candidates { |
|||
if val == cand { |
|||
return val |
|||
} |
|||
} |
|||
return defaultVal |
|||
} |
|||
|
|||
// InFloat64 always returns value without error,
|
|||
// it returns default value if error occurs or doesn't fit into candidates.
|
|||
func (k *Key) InFloat64(defaultVal float64, candidates []float64) float64 { |
|||
val := k.MustFloat64() |
|||
for _, cand := range candidates { |
|||
if val == cand { |
|||
return val |
|||
} |
|||
} |
|||
return defaultVal |
|||
} |
|||
|
|||
// InInt always returns value without error,
|
|||
// it returns default value if error occurs or doesn't fit into candidates.
|
|||
func (k *Key) InInt(defaultVal int, candidates []int) int { |
|||
val := k.MustInt() |
|||
for _, cand := range candidates { |
|||
if val == cand { |
|||
return val |
|||
} |
|||
} |
|||
return defaultVal |
|||
} |
|||
|
|||
// InInt64 always returns value without error,
|
|||
// it returns default value if error occurs or doesn't fit into candidates.
|
|||
func (k *Key) InInt64(defaultVal int64, candidates []int64) int64 { |
|||
val := k.MustInt64() |
|||
for _, cand := range candidates { |
|||
if val == cand { |
|||
return val |
|||
} |
|||
} |
|||
return defaultVal |
|||
} |
|||
|
|||
// InUint always returns value without error,
|
|||
// it returns default value if error occurs or doesn't fit into candidates.
|
|||
func (k *Key) InUint(defaultVal uint, candidates []uint) uint { |
|||
val := k.MustUint() |
|||
for _, cand := range candidates { |
|||
if val == cand { |
|||
return val |
|||
} |
|||
} |
|||
return defaultVal |
|||
} |
|||
|
|||
// InUint64 always returns value without error,
|
|||
// it returns default value if error occurs or doesn't fit into candidates.
|
|||
func (k *Key) InUint64(defaultVal uint64, candidates []uint64) uint64 { |
|||
val := k.MustUint64() |
|||
for _, cand := range candidates { |
|||
if val == cand { |
|||
return val |
|||
} |
|||
} |
|||
return defaultVal |
|||
} |
|||
|
|||
// InTimeFormat always parses with given format and returns value without error,
|
|||
// it returns default value if error occurs or doesn't fit into candidates.
|
|||
func (k *Key) InTimeFormat(format string, defaultVal time.Time, candidates []time.Time) time.Time { |
|||
val := k.MustTimeFormat(format) |
|||
for _, cand := range candidates { |
|||
if val == cand { |
|||
return val |
|||
} |
|||
} |
|||
return defaultVal |
|||
} |
|||
|
|||
// InTime always parses with RFC3339 format and returns value without error,
|
|||
// it returns default value if error occurs or doesn't fit into candidates.
|
|||
func (k *Key) InTime(defaultVal time.Time, candidates []time.Time) time.Time { |
|||
return k.InTimeFormat(time.RFC3339, defaultVal, candidates) |
|||
} |
|||
|
|||
// RangeFloat64 checks if value is in given range inclusively,
|
|||
// and returns default value if it's not.
|
|||
func (k *Key) RangeFloat64(defaultVal, min, max float64) float64 { |
|||
val := k.MustFloat64() |
|||
if val < min || val > max { |
|||
return defaultVal |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// RangeInt checks if value is in given range inclusively,
|
|||
// and returns default value if it's not.
|
|||
func (k *Key) RangeInt(defaultVal, min, max int) int { |
|||
val := k.MustInt() |
|||
if val < min || val > max { |
|||
return defaultVal |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// RangeInt64 checks if value is in given range inclusively,
|
|||
// and returns default value if it's not.
|
|||
func (k *Key) RangeInt64(defaultVal, min, max int64) int64 { |
|||
val := k.MustInt64() |
|||
if val < min || val > max { |
|||
return defaultVal |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// RangeTimeFormat checks if value with given format is in given range inclusively,
|
|||
// and returns default value if it's not.
|
|||
func (k *Key) RangeTimeFormat(format string, defaultVal, min, max time.Time) time.Time { |
|||
val := k.MustTimeFormat(format) |
|||
if val.Unix() < min.Unix() || val.Unix() > max.Unix() { |
|||
return defaultVal |
|||
} |
|||
return val |
|||
} |
|||
|
|||
// RangeTime checks if value with RFC3339 format is in given range inclusively,
|
|||
// and returns default value if it's not.
|
|||
func (k *Key) RangeTime(defaultVal, min, max time.Time) time.Time { |
|||
return k.RangeTimeFormat(time.RFC3339, defaultVal, min, max) |
|||
} |
|||
|
|||
// Strings returns list of string divided by given delimiter.
|
|||
func (k *Key) Strings(delim string) []string { |
|||
str := k.String() |
|||
if len(str) == 0 { |
|||
return []string{} |
|||
} |
|||
|
|||
vals := strings.Split(str, delim) |
|||
for i := range vals { |
|||
vals[i] = strings.TrimSpace(vals[i]) |
|||
} |
|||
return vals |
|||
} |
|||
|
|||
// Float64s returns list of float64 divided by given delimiter. Any invalid input will be treated as zero value.
|
|||
func (k *Key) Float64s(delim string) []float64 { |
|||
vals, _ := k.getFloat64s(delim, true, false) |
|||
return vals |
|||
} |
|||
|
|||
// Ints returns list of int divided by given delimiter. Any invalid input will be treated as zero value.
|
|||
func (k *Key) Ints(delim string) []int { |
|||
vals, _ := k.getInts(delim, true, false) |
|||
return vals |
|||
} |
|||
|
|||
// Int64s returns list of int64 divided by given delimiter. Any invalid input will be treated as zero value.
|
|||
func (k *Key) Int64s(delim string) []int64 { |
|||
vals, _ := k.getInt64s(delim, true, false) |
|||
return vals |
|||
} |
|||
|
|||
// Uints returns list of uint divided by given delimiter. Any invalid input will be treated as zero value.
|
|||
func (k *Key) Uints(delim string) []uint { |
|||
vals, _ := k.getUints(delim, true, false) |
|||
return vals |
|||
} |
|||
|
|||
// Uint64s returns list of uint64 divided by given delimiter. Any invalid input will be treated as zero value.
|
|||
func (k *Key) Uint64s(delim string) []uint64 { |
|||
vals, _ := k.getUint64s(delim, true, false) |
|||
return vals |
|||
} |
|||
|
|||
// TimesFormat parses with given format and returns list of time.Time divided by given delimiter.
|
|||
// Any invalid input will be treated as zero value (0001-01-01 00:00:00 +0000 UTC).
|
|||
func (k *Key) TimesFormat(format, delim string) []time.Time { |
|||
vals, _ := k.getTimesFormat(format, delim, true, false) |
|||
return vals |
|||
} |
|||
|
|||
// Times parses with RFC3339 format and returns list of time.Time divided by given delimiter.
|
|||
// Any invalid input will be treated as zero value (0001-01-01 00:00:00 +0000 UTC).
|
|||
func (k *Key) Times(delim string) []time.Time { |
|||
return k.TimesFormat(time.RFC3339, delim) |
|||
} |
|||
|
|||
// ValidFloat64s returns list of float64 divided by given delimiter. If some value is not float, then
|
|||
// it will not be included to result list.
|
|||
func (k *Key) ValidFloat64s(delim string) []float64 { |
|||
vals, _ := k.getFloat64s(delim, false, false) |
|||
return vals |
|||
} |
|||
|
|||
// ValidInts returns list of int divided by given delimiter. If some value is not integer, then it will
|
|||
// not be included to result list.
|
|||
func (k *Key) ValidInts(delim string) []int { |
|||
vals, _ := k.getInts(delim, false, false) |
|||
return vals |
|||
} |
|||
|
|||
// ValidInt64s returns list of int64 divided by given delimiter. If some value is not 64-bit integer,
|
|||
// then it will not be included to result list.
|
|||
func (k *Key) ValidInt64s(delim string) []int64 { |
|||
vals, _ := k.getInt64s(delim, false, false) |
|||
return vals |
|||
} |
|||
|
|||
// ValidUints returns list of uint divided by given delimiter. If some value is not unsigned integer,
|
|||
// then it will not be included to result list.
|
|||
func (k *Key) ValidUints(delim string) []uint { |
|||
vals, _ := k.getUints(delim, false, false) |
|||
return vals |
|||
} |
|||
|
|||
// ValidUint64s returns list of uint64 divided by given delimiter. If some value is not 64-bit unsigned
|
|||
// integer, then it will not be included to result list.
|
|||
func (k *Key) ValidUint64s(delim string) []uint64 { |
|||
vals, _ := k.getUint64s(delim, false, false) |
|||
return vals |
|||
} |
|||
|
|||
// ValidTimesFormat parses with given format and returns list of time.Time divided by given delimiter.
|
|||
func (k *Key) ValidTimesFormat(format, delim string) []time.Time { |
|||
vals, _ := k.getTimesFormat(format, delim, false, false) |
|||
return vals |
|||
} |
|||
|
|||
// ValidTimes parses with RFC3339 format and returns list of time.Time divided by given delimiter.
|
|||
func (k *Key) ValidTimes(delim string) []time.Time { |
|||
return k.ValidTimesFormat(time.RFC3339, delim) |
|||
} |
|||
|
|||
// StrictFloat64s returns list of float64 divided by given delimiter or error on first invalid input.
|
|||
func (k *Key) StrictFloat64s(delim string) ([]float64, error) { |
|||
return k.getFloat64s(delim, false, true) |
|||
} |
|||
|
|||
// StrictInts returns list of int divided by given delimiter or error on first invalid input.
|
|||
func (k *Key) StrictInts(delim string) ([]int, error) { |
|||
return k.getInts(delim, false, true) |
|||
} |
|||
|
|||
// StrictInt64s returns list of int64 divided by given delimiter or error on first invalid input.
|
|||
func (k *Key) StrictInt64s(delim string) ([]int64, error) { |
|||
return k.getInt64s(delim, false, true) |
|||
} |
|||
|
|||
// StrictUints returns list of uint divided by given delimiter or error on first invalid input.
|
|||
func (k *Key) StrictUints(delim string) ([]uint, error) { |
|||
return k.getUints(delim, false, true) |
|||
} |
|||
|
|||
// StrictUint64s returns list of uint64 divided by given delimiter or error on first invalid input.
|
|||
func (k *Key) StrictUint64s(delim string) ([]uint64, error) { |
|||
return k.getUint64s(delim, false, true) |
|||
} |
|||
|
|||
// StrictTimesFormat parses with given format and returns list of time.Time divided by given delimiter
|
|||
// or error on first invalid input.
|
|||
func (k *Key) StrictTimesFormat(format, delim string) ([]time.Time, error) { |
|||
return k.getTimesFormat(format, delim, false, true) |
|||
} |
|||
|
|||
// StrictTimes parses with RFC3339 format and returns list of time.Time divided by given delimiter
|
|||
// or error on first invalid input.
|
|||
func (k *Key) StrictTimes(delim string) ([]time.Time, error) { |
|||
return k.StrictTimesFormat(time.RFC3339, delim) |
|||
} |
|||
|
|||
// getFloat64s returns list of float64 divided by given delimiter.
|
|||
func (k *Key) getFloat64s(delim string, addInvalid, returnOnInvalid bool) ([]float64, error) { |
|||
strs := k.Strings(delim) |
|||
vals := make([]float64, 0, len(strs)) |
|||
for _, str := range strs { |
|||
val, err := strconv.ParseFloat(str, 64) |
|||
if err != nil && returnOnInvalid { |
|||
return nil, err |
|||
} |
|||
if err == nil || addInvalid { |
|||
vals = append(vals, val) |
|||
} |
|||
} |
|||
return vals, nil |
|||
} |
|||
|
|||
// getInts returns list of int divided by given delimiter.
|
|||
func (k *Key) getInts(delim string, addInvalid, returnOnInvalid bool) ([]int, error) { |
|||
strs := k.Strings(delim) |
|||
vals := make([]int, 0, len(strs)) |
|||
for _, str := range strs { |
|||
val, err := strconv.Atoi(str) |
|||
if err != nil && returnOnInvalid { |
|||
return nil, err |
|||
} |
|||
if err == nil || addInvalid { |
|||
vals = append(vals, val) |
|||
} |
|||
} |
|||
return vals, nil |
|||
} |
|||
|
|||
// getInt64s returns list of int64 divided by given delimiter.
|
|||
func (k *Key) getInt64s(delim string, addInvalid, returnOnInvalid bool) ([]int64, error) { |
|||
strs := k.Strings(delim) |
|||
vals := make([]int64, 0, len(strs)) |
|||
for _, str := range strs { |
|||
val, err := strconv.ParseInt(str, 10, 64) |
|||
if err != nil && returnOnInvalid { |
|||
return nil, err |
|||
} |
|||
if err == nil || addInvalid { |
|||
vals = append(vals, val) |
|||
} |
|||
} |
|||
return vals, nil |
|||
} |
|||
|
|||
// getUints returns list of uint divided by given delimiter.
|
|||
func (k *Key) getUints(delim string, addInvalid, returnOnInvalid bool) ([]uint, error) { |
|||
strs := k.Strings(delim) |
|||
vals := make([]uint, 0, len(strs)) |
|||
for _, str := range strs { |
|||
val, err := strconv.ParseUint(str, 10, 0) |
|||
if err != nil && returnOnInvalid { |
|||
return nil, err |
|||
} |
|||
if err == nil || addInvalid { |
|||
vals = append(vals, uint(val)) |
|||
} |
|||
} |
|||
return vals, nil |
|||
} |
|||
|
|||
// getUint64s returns list of uint64 divided by given delimiter.
|
|||
func (k *Key) getUint64s(delim string, addInvalid, returnOnInvalid bool) ([]uint64, error) { |
|||
strs := k.Strings(delim) |
|||
vals := make([]uint64, 0, len(strs)) |
|||
for _, str := range strs { |
|||
val, err := strconv.ParseUint(str, 10, 64) |
|||
if err != nil && returnOnInvalid { |
|||
return nil, err |
|||
} |
|||
if err == nil || addInvalid { |
|||
vals = append(vals, val) |
|||
} |
|||
} |
|||
return vals, nil |
|||
} |
|||
|
|||
// getTimesFormat parses with given format and returns list of time.Time divided by given delimiter.
|
|||
func (k *Key) getTimesFormat(format, delim string, addInvalid, returnOnInvalid bool) ([]time.Time, error) { |
|||
strs := k.Strings(delim) |
|||
vals := make([]time.Time, 0, len(strs)) |
|||
for _, str := range strs { |
|||
val, err := time.Parse(format, str) |
|||
if err != nil && returnOnInvalid { |
|||
return nil, err |
|||
} |
|||
if err == nil || addInvalid { |
|||
vals = append(vals, val) |
|||
} |
|||
} |
|||
return vals, nil |
|||
} |
|||
|
|||
// SetValue changes key value.
|
|||
func (k *Key) SetValue(v string) { |
|||
if k.s.f.BlockMode { |
|||
k.s.f.lock.Lock() |
|||
defer k.s.f.lock.Unlock() |
|||
} |
|||
|
|||
k.value = v |
|||
k.s.keysHash[k.name] = v |
|||
} |
@ -0,0 +1,325 @@ |
|||
// Copyright 2015 Unknwon
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
|||
// not use this file except in compliance with the License. You may obtain
|
|||
// a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|||
// License for the specific language governing permissions and limitations
|
|||
// under the License.
|
|||
|
|||
package ini |
|||
|
|||
import ( |
|||
"bufio" |
|||
"bytes" |
|||
"fmt" |
|||
"io" |
|||
"strconv" |
|||
"strings" |
|||
"unicode" |
|||
) |
|||
|
|||
type tokenType int |
|||
|
|||
const ( |
|||
_TOKEN_INVALID tokenType = iota |
|||
_TOKEN_COMMENT |
|||
_TOKEN_SECTION |
|||
_TOKEN_KEY |
|||
) |
|||
|
|||
type parser struct { |
|||
buf *bufio.Reader |
|||
isEOF bool |
|||
count int |
|||
comment *bytes.Buffer |
|||
} |
|||
|
|||
func newParser(r io.Reader) *parser { |
|||
return &parser{ |
|||
buf: bufio.NewReader(r), |
|||
count: 1, |
|||
comment: &bytes.Buffer{}, |
|||
} |
|||
} |
|||
|
|||
// BOM handles header of BOM-UTF8 format.
|
|||
// http://en.wikipedia.org/wiki/Byte_order_mark#Representations_of_byte_order_marks_by_encoding
|
|||
func (p *parser) BOM() error { |
|||
mask, err := p.buf.Peek(3) |
|||
if err != nil && err != io.EOF { |
|||
return err |
|||
} else if len(mask) < 3 { |
|||
return nil |
|||
} else if mask[0] == 239 && mask[1] == 187 && mask[2] == 191 { |
|||
p.buf.Read(mask) |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (p *parser) readUntil(delim byte) ([]byte, error) { |
|||
data, err := p.buf.ReadBytes(delim) |
|||
if err != nil { |
|||
if err == io.EOF { |
|||
p.isEOF = true |
|||
} else { |
|||
return nil, err |
|||
} |
|||
} |
|||
return data, nil |
|||
} |
|||
|
|||
func cleanComment(in []byte) ([]byte, bool) { |
|||
i := bytes.IndexAny(in, "#;") |
|||
if i == -1 { |
|||
return nil, false |
|||
} |
|||
return in[i:], true |
|||
} |
|||
|
|||
func readKeyName(in []byte) (string, int, error) { |
|||
line := string(in) |
|||
|
|||
// Check if key name surrounded by quotes.
|
|||
var keyQuote string |
|||
if line[0] == '"' { |
|||
if len(line) > 6 && string(line[0:3]) == `"""` { |
|||
keyQuote = `"""` |
|||
} else { |
|||
keyQuote = `"` |
|||
} |
|||
} else if line[0] == '`' { |
|||
keyQuote = "`" |
|||
} |
|||
|
|||
// Get out key name
|
|||
endIdx := -1 |
|||
if len(keyQuote) > 0 { |
|||
startIdx := len(keyQuote) |
|||
// FIXME: fail case -> """"""name"""=value
|
|||
pos := strings.Index(line[startIdx:], keyQuote) |
|||
if pos == -1 { |
|||
return "", -1, fmt.Errorf("missing closing key quote: %s", line) |
|||
} |
|||
pos += startIdx |
|||
|
|||
// Find key-value delimiter
|
|||
i := strings.IndexAny(line[pos+startIdx:], "=:") |
|||
if i < 0 { |
|||
return "", -1, ErrDelimiterNotFound{line} |
|||
} |
|||
endIdx = pos + i |
|||
return strings.TrimSpace(line[startIdx:pos]), endIdx + startIdx + 1, nil |
|||
} |
|||
|
|||
endIdx = strings.IndexAny(line, "=:") |
|||
if endIdx < 0 { |
|||
return "", -1, ErrDelimiterNotFound{line} |
|||
} |
|||
return strings.TrimSpace(line[0:endIdx]), endIdx + 1, nil |
|||
} |
|||
|
|||
func (p *parser) readMultilines(line, val, valQuote string) (string, error) { |
|||
for { |
|||
data, err := p.readUntil('\n') |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
next := string(data) |
|||
|
|||
pos := strings.LastIndex(next, valQuote) |
|||
if pos > -1 { |
|||
val += next[:pos] |
|||
|
|||
comment, has := cleanComment([]byte(next[pos:])) |
|||
if has { |
|||
p.comment.Write(bytes.TrimSpace(comment)) |
|||
} |
|||
break |
|||
} |
|||
val += next |
|||
if p.isEOF { |
|||
return "", fmt.Errorf("missing closing key quote from '%s' to '%s'", line, next) |
|||
} |
|||
} |
|||
return val, nil |
|||
} |
|||
|
|||
func (p *parser) readContinuationLines(val string) (string, error) { |
|||
for { |
|||
data, err := p.readUntil('\n') |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
next := strings.TrimSpace(string(data)) |
|||
|
|||
if len(next) == 0 { |
|||
break |
|||
} |
|||
val += next |
|||
if val[len(val)-1] != '\\' { |
|||
break |
|||
} |
|||
val = val[:len(val)-1] |
|||
} |
|||
return val, nil |
|||
} |
|||
|
|||
// hasSurroundedQuote check if and only if the first and last characters
|
|||
// are quotes \" or \'.
|
|||
// It returns false if any other parts also contain same kind of quotes.
|
|||
func hasSurroundedQuote(in string, quote byte) bool { |
|||
return len(in) > 2 && in[0] == quote && in[len(in)-1] == quote && |
|||
strings.IndexByte(in[1:], quote) == len(in)-2 |
|||
} |
|||
|
|||
func (p *parser) readValue(in []byte, ignoreContinuation bool) (string, error) { |
|||
line := strings.TrimLeftFunc(string(in), unicode.IsSpace) |
|||
if len(line) == 0 { |
|||
return "", nil |
|||
} |
|||
|
|||
var valQuote string |
|||
if len(line) > 3 && string(line[0:3]) == `"""` { |
|||
valQuote = `"""` |
|||
} else if line[0] == '`' { |
|||
valQuote = "`" |
|||
} |
|||
|
|||
if len(valQuote) > 0 { |
|||
startIdx := len(valQuote) |
|||
pos := strings.LastIndex(line[startIdx:], valQuote) |
|||
// Check for multi-line value
|
|||
if pos == -1 { |
|||
return p.readMultilines(line, line[startIdx:], valQuote) |
|||
} |
|||
|
|||
return line[startIdx : pos+startIdx], nil |
|||
} |
|||
|
|||
// Won't be able to reach here if value only contains whitespace.
|
|||
line = strings.TrimSpace(line) |
|||
|
|||
// Check continuation lines when desired.
|
|||
if !ignoreContinuation && line[len(line)-1] == '\\' { |
|||
return p.readContinuationLines(line[:len(line)-1]) |
|||
} |
|||
|
|||
i := strings.IndexAny(line, "#;") |
|||
if i > -1 { |
|||
p.comment.WriteString(line[i:]) |
|||
line = strings.TrimSpace(line[:i]) |
|||
} |
|||
|
|||
// Trim single quotes
|
|||
if hasSurroundedQuote(line, '\'') || |
|||
hasSurroundedQuote(line, '"') { |
|||
line = line[1 : len(line)-1] |
|||
} |
|||
return line, nil |
|||
} |
|||
|
|||
// parse parses data through an io.Reader.
|
|||
func (f *File) parse(reader io.Reader) (err error) { |
|||
p := newParser(reader) |
|||
if err = p.BOM(); err != nil { |
|||
return fmt.Errorf("BOM: %v", err) |
|||
} |
|||
|
|||
// Ignore error because default section name is never empty string.
|
|||
section, _ := f.NewSection(DEFAULT_SECTION) |
|||
|
|||
var line []byte |
|||
for !p.isEOF { |
|||
line, err = p.readUntil('\n') |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
line = bytes.TrimLeftFunc(line, unicode.IsSpace) |
|||
if len(line) == 0 { |
|||
continue |
|||
} |
|||
|
|||
// Comments
|
|||
if line[0] == '#' || line[0] == ';' { |
|||
// Note: we do not care ending line break,
|
|||
// it is needed for adding second line,
|
|||
// so just clean it once at the end when set to value.
|
|||
p.comment.Write(line) |
|||
continue |
|||
} |
|||
|
|||
// Section
|
|||
if line[0] == '[' { |
|||
// Read to the next ']' (TODO: support quoted strings)
|
|||
// TODO(unknwon): use LastIndexByte when stop supporting Go1.4
|
|||
closeIdx := bytes.LastIndex(line, []byte("]")) |
|||
if closeIdx == -1 { |
|||
return fmt.Errorf("unclosed section: %s", line) |
|||
} |
|||
|
|||
name := string(line[1:closeIdx]) |
|||
section, err = f.NewSection(name) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
comment, has := cleanComment(line[closeIdx+1:]) |
|||
if has { |
|||
p.comment.Write(comment) |
|||
} |
|||
|
|||
section.Comment = strings.TrimSpace(p.comment.String()) |
|||
|
|||
// Reset aotu-counter and comments
|
|||
p.comment.Reset() |
|||
p.count = 1 |
|||
continue |
|||
} |
|||
|
|||
kname, offset, err := readKeyName(line) |
|||
if err != nil { |
|||
// Treat as boolean key when desired, and whole line is key name.
|
|||
if IsErrDelimiterNotFound(err) && f.options.AllowBooleanKeys { |
|||
key, err := section.NewKey(string(line), "true") |
|||
if err != nil { |
|||
return err |
|||
} |
|||
key.isBooleanType = true |
|||
key.Comment = strings.TrimSpace(p.comment.String()) |
|||
p.comment.Reset() |
|||
continue |
|||
} |
|||
return err |
|||
} |
|||
|
|||
// Auto increment.
|
|||
isAutoIncr := false |
|||
if kname == "-" { |
|||
isAutoIncr = true |
|||
kname = "#" + strconv.Itoa(p.count) |
|||
p.count++ |
|||
} |
|||
|
|||
key, err := section.NewKey(kname, "") |
|||
if err != nil { |
|||
return err |
|||
} |
|||
key.isAutoIncrement = isAutoIncr |
|||
|
|||
value, err := p.readValue(line[offset:], f.options.IgnoreContinuation) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
key.SetValue(value) |
|||
key.Comment = strings.TrimSpace(p.comment.String()) |
|||
p.comment.Reset() |
|||
} |
|||
return nil |
|||
} |
@ -0,0 +1,206 @@ |
|||
// Copyright 2014 Unknwon
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
|||
// not use this file except in compliance with the License. You may obtain
|
|||
// a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|||
// License for the specific language governing permissions and limitations
|
|||
// under the License.
|
|||
|
|||
package ini |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
"strings" |
|||
) |
|||
|
|||
// Section represents a config section.
|
|||
type Section struct { |
|||
f *File |
|||
Comment string |
|||
name string |
|||
keys map[string]*Key |
|||
keyList []string |
|||
keysHash map[string]string |
|||
} |
|||
|
|||
func newSection(f *File, name string) *Section { |
|||
return &Section{f, "", name, make(map[string]*Key), make([]string, 0, 10), make(map[string]string)} |
|||
} |
|||
|
|||
// Name returns name of Section.
|
|||
func (s *Section) Name() string { |
|||
return s.name |
|||
} |
|||
|
|||
// NewKey creates a new key to given section.
|
|||
func (s *Section) NewKey(name, val string) (*Key, error) { |
|||
if len(name) == 0 { |
|||
return nil, errors.New("error creating new key: empty key name") |
|||
} else if s.f.options.Insensitive { |
|||
name = strings.ToLower(name) |
|||
} |
|||
|
|||
if s.f.BlockMode { |
|||
s.f.lock.Lock() |
|||
defer s.f.lock.Unlock() |
|||
} |
|||
|
|||
if inSlice(name, s.keyList) { |
|||
s.keys[name].value = val |
|||
return s.keys[name], nil |
|||
} |
|||
|
|||
s.keyList = append(s.keyList, name) |
|||
s.keys[name] = &Key{ |
|||
s: s, |
|||
name: name, |
|||
value: val, |
|||
} |
|||
s.keysHash[name] = val |
|||
return s.keys[name], nil |
|||
} |
|||
|
|||
// GetKey returns key in section by given name.
|
|||
func (s *Section) GetKey(name string) (*Key, error) { |
|||
// FIXME: change to section level lock?
|
|||
if s.f.BlockMode { |
|||
s.f.lock.RLock() |
|||
} |
|||
if s.f.options.Insensitive { |
|||
name = strings.ToLower(name) |
|||
} |
|||
key := s.keys[name] |
|||
if s.f.BlockMode { |
|||
s.f.lock.RUnlock() |
|||
} |
|||
|
|||
if key == nil { |
|||
// Check if it is a child-section.
|
|||
sname := s.name |
|||
for { |
|||
if i := strings.LastIndex(sname, "."); i > -1 { |
|||
sname = sname[:i] |
|||
sec, err := s.f.GetSection(sname) |
|||
if err != nil { |
|||
continue |
|||
} |
|||
return sec.GetKey(name) |
|||
} else { |
|||
break |
|||
} |
|||
} |
|||
return nil, fmt.Errorf("error when getting key of section '%s': key '%s' not exists", s.name, name) |
|||
} |
|||
return key, nil |
|||
} |
|||
|
|||
// HasKey returns true if section contains a key with given name.
|
|||
func (s *Section) HasKey(name string) bool { |
|||
key, _ := s.GetKey(name) |
|||
return key != nil |
|||
} |
|||
|
|||
// Haskey is a backwards-compatible name for HasKey.
|
|||
func (s *Section) Haskey(name string) bool { |
|||
return s.HasKey(name) |
|||
} |
|||
|
|||
// HasValue returns true if section contains given raw value.
|
|||
func (s *Section) HasValue(value string) bool { |
|||
if s.f.BlockMode { |
|||
s.f.lock.RLock() |
|||
defer s.f.lock.RUnlock() |
|||
} |
|||
|
|||
for _, k := range s.keys { |
|||
if value == k.value { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// Key assumes named Key exists in section and returns a zero-value when not.
|
|||
func (s *Section) Key(name string) *Key { |
|||
key, err := s.GetKey(name) |
|||
if err != nil { |
|||
// It's OK here because the only possible error is empty key name,
|
|||
// but if it's empty, this piece of code won't be executed.
|
|||
key, _ = s.NewKey(name, "") |
|||
return key |
|||
} |
|||
return key |
|||
} |
|||
|
|||
// Keys returns list of keys of section.
|
|||
func (s *Section) Keys() []*Key { |
|||
keys := make([]*Key, len(s.keyList)) |
|||
for i := range s.keyList { |
|||
keys[i] = s.Key(s.keyList[i]) |
|||
} |
|||
return keys |
|||
} |
|||
|
|||
// ParentKeys returns list of keys of parent section.
|
|||
func (s *Section) ParentKeys() []*Key { |
|||
var parentKeys []*Key |
|||
sname := s.name |
|||
for { |
|||
if i := strings.LastIndex(sname, "."); i > -1 { |
|||
sname = sname[:i] |
|||
sec, err := s.f.GetSection(sname) |
|||
if err != nil { |
|||
continue |
|||
} |
|||
parentKeys = append(parentKeys, sec.Keys()...) |
|||
} else { |
|||
break |
|||
} |
|||
|
|||
} |
|||
return parentKeys |
|||
} |
|||
|
|||
// KeyStrings returns list of key names of section.
|
|||
func (s *Section) KeyStrings() []string { |
|||
list := make([]string, len(s.keyList)) |
|||
copy(list, s.keyList) |
|||
return list |
|||
} |
|||
|
|||
// KeysHash returns keys hash consisting of names and values.
|
|||
func (s *Section) KeysHash() map[string]string { |
|||
if s.f.BlockMode { |
|||
s.f.lock.RLock() |
|||
defer s.f.lock.RUnlock() |
|||
} |
|||
|
|||
hash := map[string]string{} |
|||
for key, value := range s.keysHash { |
|||
hash[key] = value |
|||
} |
|||
return hash |
|||
} |
|||
|
|||
// DeleteKey deletes a key from section.
|
|||
func (s *Section) DeleteKey(name string) { |
|||
if s.f.BlockMode { |
|||
s.f.lock.Lock() |
|||
defer s.f.lock.Unlock() |
|||
} |
|||
|
|||
for i, k := range s.keyList { |
|||
if k == name { |
|||
s.keyList = append(s.keyList[:i], s.keyList[i+1:]...) |
|||
delete(s.keys, name) |
|||
return |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,431 @@ |
|||
// Copyright 2014 Unknwon
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
|||
// not use this file except in compliance with the License. You may obtain
|
|||
// a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|||
// License for the specific language governing permissions and limitations
|
|||
// under the License.
|
|||
|
|||
package ini |
|||
|
|||
import ( |
|||
"bytes" |
|||
"errors" |
|||
"fmt" |
|||
"reflect" |
|||
"strings" |
|||
"time" |
|||
"unicode" |
|||
) |
|||
|
|||
// NameMapper represents a ini tag name mapper.
|
|||
type NameMapper func(string) string |
|||
|
|||
// Built-in name getters.
|
|||
var ( |
|||
// AllCapsUnderscore converts to format ALL_CAPS_UNDERSCORE.
|
|||
AllCapsUnderscore NameMapper = func(raw string) string { |
|||
newstr := make([]rune, 0, len(raw)) |
|||
for i, chr := range raw { |
|||
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { |
|||
if i > 0 { |
|||
newstr = append(newstr, '_') |
|||
} |
|||
} |
|||
newstr = append(newstr, unicode.ToUpper(chr)) |
|||
} |
|||
return string(newstr) |
|||
} |
|||
// TitleUnderscore converts to format title_underscore.
|
|||
TitleUnderscore NameMapper = func(raw string) string { |
|||
newstr := make([]rune, 0, len(raw)) |
|||
for i, chr := range raw { |
|||
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { |
|||
if i > 0 { |
|||
newstr = append(newstr, '_') |
|||
} |
|||
chr -= ('A' - 'a') |
|||
} |
|||
newstr = append(newstr, chr) |
|||
} |
|||
return string(newstr) |
|||
} |
|||
) |
|||
|
|||
func (s *Section) parseFieldName(raw, actual string) string { |
|||
if len(actual) > 0 { |
|||
return actual |
|||
} |
|||
if s.f.NameMapper != nil { |
|||
return s.f.NameMapper(raw) |
|||
} |
|||
return raw |
|||
} |
|||
|
|||
func parseDelim(actual string) string { |
|||
if len(actual) > 0 { |
|||
return actual |
|||
} |
|||
return "," |
|||
} |
|||
|
|||
var reflectTime = reflect.TypeOf(time.Now()).Kind() |
|||
|
|||
// setSliceWithProperType sets proper values to slice based on its type.
|
|||
func setSliceWithProperType(key *Key, field reflect.Value, delim string) error { |
|||
strs := key.Strings(delim) |
|||
numVals := len(strs) |
|||
if numVals == 0 { |
|||
return nil |
|||
} |
|||
|
|||
var vals interface{} |
|||
|
|||
sliceOf := field.Type().Elem().Kind() |
|||
switch sliceOf { |
|||
case reflect.String: |
|||
vals = strs |
|||
case reflect.Int: |
|||
vals = key.Ints(delim) |
|||
case reflect.Int64: |
|||
vals = key.Int64s(delim) |
|||
case reflect.Uint: |
|||
vals = key.Uints(delim) |
|||
case reflect.Uint64: |
|||
vals = key.Uint64s(delim) |
|||
case reflect.Float64: |
|||
vals = key.Float64s(delim) |
|||
case reflectTime: |
|||
vals = key.Times(delim) |
|||
default: |
|||
return fmt.Errorf("unsupported type '[]%s'", sliceOf) |
|||
} |
|||
|
|||
slice := reflect.MakeSlice(field.Type(), numVals, numVals) |
|||
for i := 0; i < numVals; i++ { |
|||
switch sliceOf { |
|||
case reflect.String: |
|||
slice.Index(i).Set(reflect.ValueOf(vals.([]string)[i])) |
|||
case reflect.Int: |
|||
slice.Index(i).Set(reflect.ValueOf(vals.([]int)[i])) |
|||
case reflect.Int64: |
|||
slice.Index(i).Set(reflect.ValueOf(vals.([]int64)[i])) |
|||
case reflect.Uint: |
|||
slice.Index(i).Set(reflect.ValueOf(vals.([]uint)[i])) |
|||
case reflect.Uint64: |
|||
slice.Index(i).Set(reflect.ValueOf(vals.([]uint64)[i])) |
|||
case reflect.Float64: |
|||
slice.Index(i).Set(reflect.ValueOf(vals.([]float64)[i])) |
|||
case reflectTime: |
|||
slice.Index(i).Set(reflect.ValueOf(vals.([]time.Time)[i])) |
|||
} |
|||
} |
|||
field.Set(slice) |
|||
return nil |
|||
} |
|||
|
|||
// setWithProperType sets proper value to field based on its type,
|
|||
// but it does not return error for failing parsing,
|
|||
// because we want to use default value that is already assigned to strcut.
|
|||
func setWithProperType(t reflect.Type, key *Key, field reflect.Value, delim string) error { |
|||
switch t.Kind() { |
|||
case reflect.String: |
|||
if len(key.String()) == 0 { |
|||
return nil |
|||
} |
|||
field.SetString(key.String()) |
|||
case reflect.Bool: |
|||
boolVal, err := key.Bool() |
|||
if err != nil { |
|||
return nil |
|||
} |
|||
field.SetBool(boolVal) |
|||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
|||
durationVal, err := key.Duration() |
|||
// Skip zero value
|
|||
if err == nil && int(durationVal) > 0 { |
|||
field.Set(reflect.ValueOf(durationVal)) |
|||
return nil |
|||
} |
|||
|
|||
intVal, err := key.Int64() |
|||
if err != nil || intVal == 0 { |
|||
return nil |
|||
} |
|||
field.SetInt(intVal) |
|||
// byte is an alias for uint8, so supporting uint8 breaks support for byte
|
|||
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
|||
durationVal, err := key.Duration() |
|||
// Skip zero value
|
|||
if err == nil && int(durationVal) > 0 { |
|||
field.Set(reflect.ValueOf(durationVal)) |
|||
return nil |
|||
} |
|||
|
|||
uintVal, err := key.Uint64() |
|||
if err != nil { |
|||
return nil |
|||
} |
|||
field.SetUint(uintVal) |
|||
|
|||
case reflect.Float64: |
|||
floatVal, err := key.Float64() |
|||
if err != nil { |
|||
return nil |
|||
} |
|||
field.SetFloat(floatVal) |
|||
case reflectTime: |
|||
timeVal, err := key.Time() |
|||
if err != nil { |
|||
return nil |
|||
} |
|||
field.Set(reflect.ValueOf(timeVal)) |
|||
case reflect.Slice: |
|||
return setSliceWithProperType(key, field, delim) |
|||
default: |
|||
return fmt.Errorf("unsupported type '%s'", t) |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (s *Section) mapTo(val reflect.Value) error { |
|||
if val.Kind() == reflect.Ptr { |
|||
val = val.Elem() |
|||
} |
|||
typ := val.Type() |
|||
|
|||
for i := 0; i < typ.NumField(); i++ { |
|||
field := val.Field(i) |
|||
tpField := typ.Field(i) |
|||
|
|||
tag := tpField.Tag.Get("ini") |
|||
if tag == "-" { |
|||
continue |
|||
} |
|||
|
|||
opts := strings.SplitN(tag, ",", 2) // strip off possible omitempty
|
|||
fieldName := s.parseFieldName(tpField.Name, opts[0]) |
|||
if len(fieldName) == 0 || !field.CanSet() { |
|||
continue |
|||
} |
|||
|
|||
isAnonymous := tpField.Type.Kind() == reflect.Ptr && tpField.Anonymous |
|||
isStruct := tpField.Type.Kind() == reflect.Struct |
|||
if isAnonymous { |
|||
field.Set(reflect.New(tpField.Type.Elem())) |
|||
} |
|||
|
|||
if isAnonymous || isStruct { |
|||
if sec, err := s.f.GetSection(fieldName); err == nil { |
|||
if err = sec.mapTo(field); err != nil { |
|||
return fmt.Errorf("error mapping field(%s): %v", fieldName, err) |
|||
} |
|||
continue |
|||
} |
|||
} |
|||
|
|||
if key, err := s.GetKey(fieldName); err == nil { |
|||
if err = setWithProperType(tpField.Type, key, field, parseDelim(tpField.Tag.Get("delim"))); err != nil { |
|||
return fmt.Errorf("error mapping field(%s): %v", fieldName, err) |
|||
} |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// MapTo maps section to given struct.
|
|||
func (s *Section) MapTo(v interface{}) error { |
|||
typ := reflect.TypeOf(v) |
|||
val := reflect.ValueOf(v) |
|||
if typ.Kind() == reflect.Ptr { |
|||
typ = typ.Elem() |
|||
val = val.Elem() |
|||
} else { |
|||
return errors.New("cannot map to non-pointer struct") |
|||
} |
|||
|
|||
return s.mapTo(val) |
|||
} |
|||
|
|||
// MapTo maps file to given struct.
|
|||
func (f *File) MapTo(v interface{}) error { |
|||
return f.Section("").MapTo(v) |
|||
} |
|||
|
|||
// MapTo maps data sources to given struct with name mapper.
|
|||
func MapToWithMapper(v interface{}, mapper NameMapper, source interface{}, others ...interface{}) error { |
|||
cfg, err := Load(source, others...) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
cfg.NameMapper = mapper |
|||
return cfg.MapTo(v) |
|||
} |
|||
|
|||
// MapTo maps data sources to given struct.
|
|||
func MapTo(v, source interface{}, others ...interface{}) error { |
|||
return MapToWithMapper(v, nil, source, others...) |
|||
} |
|||
|
|||
// reflectSliceWithProperType does the opposite thing as setSliceWithProperType.
|
|||
func reflectSliceWithProperType(key *Key, field reflect.Value, delim string) error { |
|||
slice := field.Slice(0, field.Len()) |
|||
if field.Len() == 0 { |
|||
return nil |
|||
} |
|||
|
|||
var buf bytes.Buffer |
|||
sliceOf := field.Type().Elem().Kind() |
|||
for i := 0; i < field.Len(); i++ { |
|||
switch sliceOf { |
|||
case reflect.String: |
|||
buf.WriteString(slice.Index(i).String()) |
|||
case reflect.Int, reflect.Int64: |
|||
buf.WriteString(fmt.Sprint(slice.Index(i).Int())) |
|||
case reflect.Uint, reflect.Uint64: |
|||
buf.WriteString(fmt.Sprint(slice.Index(i).Uint())) |
|||
case reflect.Float64: |
|||
buf.WriteString(fmt.Sprint(slice.Index(i).Float())) |
|||
case reflectTime: |
|||
buf.WriteString(slice.Index(i).Interface().(time.Time).Format(time.RFC3339)) |
|||
default: |
|||
return fmt.Errorf("unsupported type '[]%s'", sliceOf) |
|||
} |
|||
buf.WriteString(delim) |
|||
} |
|||
key.SetValue(buf.String()[:buf.Len()-1]) |
|||
return nil |
|||
} |
|||
|
|||
// reflectWithProperType does the opposite thing as setWithProperType.
|
|||
func reflectWithProperType(t reflect.Type, key *Key, field reflect.Value, delim string) error { |
|||
switch t.Kind() { |
|||
case reflect.String: |
|||
key.SetValue(field.String()) |
|||
case reflect.Bool: |
|||
key.SetValue(fmt.Sprint(field.Bool())) |
|||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
|||
key.SetValue(fmt.Sprint(field.Int())) |
|||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
|||
key.SetValue(fmt.Sprint(field.Uint())) |
|||
case reflect.Float32, reflect.Float64: |
|||
key.SetValue(fmt.Sprint(field.Float())) |
|||
case reflectTime: |
|||
key.SetValue(fmt.Sprint(field.Interface().(time.Time).Format(time.RFC3339))) |
|||
case reflect.Slice: |
|||
return reflectSliceWithProperType(key, field, delim) |
|||
default: |
|||
return fmt.Errorf("unsupported type '%s'", t) |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// CR: copied from encoding/json/encode.go with modifications of time.Time support.
|
|||
// TODO: add more test coverage.
|
|||
func isEmptyValue(v reflect.Value) bool { |
|||
switch v.Kind() { |
|||
case reflect.Array, reflect.Map, reflect.Slice, reflect.String: |
|||
return v.Len() == 0 |
|||
case reflect.Bool: |
|||
return !v.Bool() |
|||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
|||
return v.Int() == 0 |
|||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: |
|||
return v.Uint() == 0 |
|||
case reflect.Float32, reflect.Float64: |
|||
return v.Float() == 0 |
|||
case reflectTime: |
|||
return v.Interface().(time.Time).IsZero() |
|||
case reflect.Interface, reflect.Ptr: |
|||
return v.IsNil() |
|||
} |
|||
return false |
|||
} |
|||
|
|||
func (s *Section) reflectFrom(val reflect.Value) error { |
|||
if val.Kind() == reflect.Ptr { |
|||
val = val.Elem() |
|||
} |
|||
typ := val.Type() |
|||
|
|||
for i := 0; i < typ.NumField(); i++ { |
|||
field := val.Field(i) |
|||
tpField := typ.Field(i) |
|||
|
|||
tag := tpField.Tag.Get("ini") |
|||
if tag == "-" { |
|||
continue |
|||
} |
|||
|
|||
opts := strings.SplitN(tag, ",", 2) |
|||
if len(opts) == 2 && opts[1] == "omitempty" && isEmptyValue(field) { |
|||
continue |
|||
} |
|||
|
|||
fieldName := s.parseFieldName(tpField.Name, opts[0]) |
|||
if len(fieldName) == 0 || !field.CanSet() { |
|||
continue |
|||
} |
|||
|
|||
if (tpField.Type.Kind() == reflect.Ptr && tpField.Anonymous) || |
|||
(tpField.Type.Kind() == reflect.Struct && tpField.Type.Name() != "Time") { |
|||
// Note: The only error here is section doesn't exist.
|
|||
sec, err := s.f.GetSection(fieldName) |
|||
if err != nil { |
|||
// Note: fieldName can never be empty here, ignore error.
|
|||
sec, _ = s.f.NewSection(fieldName) |
|||
} |
|||
if err = sec.reflectFrom(field); err != nil { |
|||
return fmt.Errorf("error reflecting field (%s): %v", fieldName, err) |
|||
} |
|||
continue |
|||
} |
|||
|
|||
// Note: Same reason as secion.
|
|||
key, err := s.GetKey(fieldName) |
|||
if err != nil { |
|||
key, _ = s.NewKey(fieldName, "") |
|||
} |
|||
if err = reflectWithProperType(tpField.Type, key, field, parseDelim(tpField.Tag.Get("delim"))); err != nil { |
|||
return fmt.Errorf("error reflecting field (%s): %v", fieldName, err) |
|||
} |
|||
|
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// ReflectFrom reflects secion from given struct.
|
|||
func (s *Section) ReflectFrom(v interface{}) error { |
|||
typ := reflect.TypeOf(v) |
|||
val := reflect.ValueOf(v) |
|||
if typ.Kind() == reflect.Ptr { |
|||
typ = typ.Elem() |
|||
val = val.Elem() |
|||
} else { |
|||
return errors.New("cannot reflect from non-pointer struct") |
|||
} |
|||
|
|||
return s.reflectFrom(val) |
|||
} |
|||
|
|||
// ReflectFrom reflects file from given struct.
|
|||
func (f *File) ReflectFrom(v interface{}) error { |
|||
return f.Section("").ReflectFrom(v) |
|||
} |
|||
|
|||
// ReflectFrom reflects data sources from given struct with name mapper.
|
|||
func ReflectFromWithMapper(cfg *File, v interface{}, mapper NameMapper) error { |
|||
cfg.NameMapper = mapper |
|||
return cfg.ReflectFrom(v) |
|||
} |
|||
|
|||
// ReflectFrom reflects data sources from given struct.
|
|||
func ReflectFrom(cfg *File, v interface{}) error { |
|||
return ReflectFromWithMapper(cfg, v, nil) |
|||
} |
@ -0,0 +1,13 @@ |
|||
Copyright 2015 James Saryerwinnie |
|||
|
|||
Licensed under the Apache License, Version 2.0 (the "License"); |
|||
you may not use this file except in compliance with the License. |
|||
You may obtain a copy of the License at |
|||
|
|||
http://www.apache.org/licenses/LICENSE-2.0 |
|||
|
|||
Unless required by applicable law or agreed to in writing, software |
|||
distributed under the License is distributed on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
See the License for the specific language governing permissions and |
|||
limitations under the License. |
@ -0,0 +1,44 @@ |
|||
|
|||
CMD = jpgo |
|||
|
|||
help: |
|||
@echo "Please use \`make <target>' where <target> is one of" |
|||
@echo " test to run all the tests" |
|||
@echo " build to build the library and jp executable" |
|||
@echo " generate to run codegen" |
|||
|
|||
|
|||
generate: |
|||
go generate ./... |
|||
|
|||
build: |
|||
rm -f $(CMD) |
|||
go build ./... |
|||
rm -f cmd/$(CMD)/$(CMD) && cd cmd/$(CMD)/ && go build ./... |
|||
mv cmd/$(CMD)/$(CMD) . |
|||
|
|||
test: |
|||
go test -v ./... |
|||
|
|||
check: |
|||
go vet ./... |
|||
@echo "golint ./..." |
|||
@lint=`golint ./...`; \
|
|||
lint=`echo "$$lint" | grep -v "astnodetype_string.go" | grep -v "toktype_string.go"`; \
|
|||
echo "$$lint"; \
|
|||
if [ "$$lint" != "" ]; then exit 1; fi |
|||
|
|||
htmlc: |
|||
go test -coverprofile="/tmp/jpcov" && go tool cover -html="/tmp/jpcov" && unlink /tmp/jpcov |
|||
|
|||
buildfuzz: |
|||
go-fuzz-build github.com/jmespath/go-jmespath/fuzz |
|||
|
|||
fuzz: buildfuzz |
|||
go-fuzz -bin=./jmespath-fuzz.zip -workdir=fuzz/testdata |
|||
|
|||
bench: |
|||
go test -bench . -cpuprofile cpu.out |
|||
|
|||
pprof-cpu: |
|||
go tool pprof ./go-jmespath.test ./cpu.out |
@ -0,0 +1,7 @@ |
|||
# go-jmespath - A JMESPath implementation in Go |
|||
|
|||
[![Build Status](https://img.shields.io/travis/jmespath/go-jmespath.svg)](https://travis-ci.org/jmespath/go-jmespath) |
|||
|
|||
|
|||
|
|||
See http://jmespath.org for more info. |
@ -0,0 +1,49 @@ |
|||
package jmespath |
|||
|
|||
import "strconv" |
|||
|
|||
// JmesPath is the epresentation of a compiled JMES path query. A JmesPath is
|
|||
// safe for concurrent use by multiple goroutines.
|
|||
type JMESPath struct { |
|||
ast ASTNode |
|||
intr *treeInterpreter |
|||
} |
|||
|
|||
// Compile parses a JMESPath expression and returns, if successful, a JMESPath
|
|||
// object that can be used to match against data.
|
|||
func Compile(expression string) (*JMESPath, error) { |
|||
parser := NewParser() |
|||
ast, err := parser.Parse(expression) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
jmespath := &JMESPath{ast: ast, intr: newInterpreter()} |
|||
return jmespath, nil |
|||
} |
|||
|
|||
// MustCompile is like Compile but panics if the expression cannot be parsed.
|
|||
// It simplifies safe initialization of global variables holding compiled
|
|||
// JMESPaths.
|
|||
func MustCompile(expression string) *JMESPath { |
|||
jmespath, err := Compile(expression) |
|||
if err != nil { |
|||
panic(`jmespath: Compile(` + strconv.Quote(expression) + `): ` + err.Error()) |
|||
} |
|||
return jmespath |
|||
} |
|||
|
|||
// Search evaluates a JMESPath expression against input data and returns the result.
|
|||
func (jp *JMESPath) Search(data interface{}) (interface{}, error) { |
|||
return jp.intr.Execute(jp.ast, data) |
|||
} |
|||
|
|||
// Search evaluates a JMESPath expression against input data and returns the result.
|
|||
func Search(expression string, data interface{}) (interface{}, error) { |
|||
intr := newInterpreter() |
|||
parser := NewParser() |
|||
ast, err := parser.Parse(expression) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return intr.Execute(ast, data) |
|||
} |
@ -0,0 +1,16 @@ |
|||
// generated by stringer -type astNodeType; DO NOT EDIT
|
|||
|
|||
package jmespath |
|||
|
|||
import "fmt" |
|||
|
|||
const _astNodeType_name = "ASTEmptyASTComparatorASTCurrentNodeASTExpRefASTFunctionExpressionASTFieldASTFilterProjectionASTFlattenASTIdentityASTIndexASTIndexExpressionASTKeyValPairASTLiteralASTMultiSelectHashASTMultiSelectListASTOrExpressionASTAndExpressionASTNotExpressionASTPipeASTProjectionASTSubexpressionASTSliceASTValueProjection" |
|||
|
|||
var _astNodeType_index = [...]uint16{0, 8, 21, 35, 44, 65, 73, 92, 102, 113, 121, 139, 152, 162, 180, 198, 213, 229, 245, 252, 265, 281, 289, 307} |
|||
|
|||
func (i astNodeType) String() string { |
|||
if i < 0 || i >= astNodeType(len(_astNodeType_index)-1) { |
|||
return fmt.Sprintf("astNodeType(%d)", i) |
|||
} |
|||
return _astNodeType_name[_astNodeType_index[i]:_astNodeType_index[i+1]] |
|||
} |
@ -0,0 +1,842 @@ |
|||
package jmespath |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"errors" |
|||
"fmt" |
|||
"math" |
|||
"reflect" |
|||
"sort" |
|||
"strconv" |
|||
"strings" |
|||
"unicode/utf8" |
|||
) |
|||
|
|||
type jpFunction func(arguments []interface{}) (interface{}, error) |
|||
|
|||
type jpType string |
|||
|
|||
const ( |
|||
jpUnknown jpType = "unknown" |
|||
jpNumber jpType = "number" |
|||
jpString jpType = "string" |
|||
jpArray jpType = "array" |
|||
jpObject jpType = "object" |
|||
jpArrayNumber jpType = "array[number]" |
|||
jpArrayString jpType = "array[string]" |
|||
jpExpref jpType = "expref" |
|||
jpAny jpType = "any" |
|||
) |
|||
|
|||
type functionEntry struct { |
|||
name string |
|||
arguments []argSpec |
|||
handler jpFunction |
|||
hasExpRef bool |
|||
} |
|||
|
|||
type argSpec struct { |
|||
types []jpType |
|||
variadic bool |
|||
} |
|||
|
|||
type byExprString struct { |
|||
intr *treeInterpreter |
|||
node ASTNode |
|||
items []interface{} |
|||
hasError bool |
|||
} |
|||
|
|||
func (a *byExprString) Len() int { |
|||
return len(a.items) |
|||
} |
|||
func (a *byExprString) Swap(i, j int) { |
|||
a.items[i], a.items[j] = a.items[j], a.items[i] |
|||
} |
|||
func (a *byExprString) Less(i, j int) bool { |
|||
first, err := a.intr.Execute(a.node, a.items[i]) |
|||
if err != nil { |
|||
a.hasError = true |
|||
// Return a dummy value.
|
|||
return true |
|||
} |
|||
ith, ok := first.(string) |
|||
if !ok { |
|||
a.hasError = true |
|||
return true |
|||
} |
|||
second, err := a.intr.Execute(a.node, a.items[j]) |
|||
if err != nil { |
|||
a.hasError = true |
|||
// Return a dummy value.
|
|||
return true |
|||
} |
|||
jth, ok := second.(string) |
|||
if !ok { |
|||
a.hasError = true |
|||
return true |
|||
} |
|||
return ith < jth |
|||
} |
|||
|
|||
type byExprFloat struct { |
|||
intr *treeInterpreter |
|||
node ASTNode |
|||
items []interface{} |
|||
hasError bool |
|||
} |
|||
|
|||
func (a *byExprFloat) Len() int { |
|||
return len(a.items) |
|||
} |
|||
func (a *byExprFloat) Swap(i, j int) { |
|||
a.items[i], a.items[j] = a.items[j], a.items[i] |
|||
} |
|||
func (a *byExprFloat) Less(i, j int) bool { |
|||
first, err := a.intr.Execute(a.node, a.items[i]) |
|||
if err != nil { |
|||
a.hasError = true |
|||
// Return a dummy value.
|
|||
return true |
|||
} |
|||
ith, ok := first.(float64) |
|||
if !ok { |
|||
a.hasError = true |
|||
return true |
|||
} |
|||
second, err := a.intr.Execute(a.node, a.items[j]) |
|||
if err != nil { |
|||
a.hasError = true |
|||
// Return a dummy value.
|
|||
return true |
|||
} |
|||
jth, ok := second.(float64) |
|||
if !ok { |
|||
a.hasError = true |
|||
return true |
|||
} |
|||
return ith < jth |
|||
} |
|||
|
|||
type functionCaller struct { |
|||
functionTable map[string]functionEntry |
|||
} |
|||
|
|||
func newFunctionCaller() *functionCaller { |
|||
caller := &functionCaller{} |
|||
caller.functionTable = map[string]functionEntry{ |
|||
"length": { |
|||
name: "length", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpString, jpArray, jpObject}}, |
|||
}, |
|||
handler: jpfLength, |
|||
}, |
|||
"starts_with": { |
|||
name: "starts_with", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpString}}, |
|||
{types: []jpType{jpString}}, |
|||
}, |
|||
handler: jpfStartsWith, |
|||
}, |
|||
"abs": { |
|||
name: "abs", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpNumber}}, |
|||
}, |
|||
handler: jpfAbs, |
|||
}, |
|||
"avg": { |
|||
name: "avg", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArrayNumber}}, |
|||
}, |
|||
handler: jpfAvg, |
|||
}, |
|||
"ceil": { |
|||
name: "ceil", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpNumber}}, |
|||
}, |
|||
handler: jpfCeil, |
|||
}, |
|||
"contains": { |
|||
name: "contains", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArray, jpString}}, |
|||
{types: []jpType{jpAny}}, |
|||
}, |
|||
handler: jpfContains, |
|||
}, |
|||
"ends_with": { |
|||
name: "ends_with", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpString}}, |
|||
{types: []jpType{jpString}}, |
|||
}, |
|||
handler: jpfEndsWith, |
|||
}, |
|||
"floor": { |
|||
name: "floor", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpNumber}}, |
|||
}, |
|||
handler: jpfFloor, |
|||
}, |
|||
"map": { |
|||
name: "amp", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpExpref}}, |
|||
{types: []jpType{jpArray}}, |
|||
}, |
|||
handler: jpfMap, |
|||
hasExpRef: true, |
|||
}, |
|||
"max": { |
|||
name: "max", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArrayNumber, jpArrayString}}, |
|||
}, |
|||
handler: jpfMax, |
|||
}, |
|||
"merge": { |
|||
name: "merge", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpObject}, variadic: true}, |
|||
}, |
|||
handler: jpfMerge, |
|||
}, |
|||
"max_by": { |
|||
name: "max_by", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArray}}, |
|||
{types: []jpType{jpExpref}}, |
|||
}, |
|||
handler: jpfMaxBy, |
|||
hasExpRef: true, |
|||
}, |
|||
"sum": { |
|||
name: "sum", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArrayNumber}}, |
|||
}, |
|||
handler: jpfSum, |
|||
}, |
|||
"min": { |
|||
name: "min", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArrayNumber, jpArrayString}}, |
|||
}, |
|||
handler: jpfMin, |
|||
}, |
|||
"min_by": { |
|||
name: "min_by", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArray}}, |
|||
{types: []jpType{jpExpref}}, |
|||
}, |
|||
handler: jpfMinBy, |
|||
hasExpRef: true, |
|||
}, |
|||
"type": { |
|||
name: "type", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpAny}}, |
|||
}, |
|||
handler: jpfType, |
|||
}, |
|||
"keys": { |
|||
name: "keys", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpObject}}, |
|||
}, |
|||
handler: jpfKeys, |
|||
}, |
|||
"values": { |
|||
name: "values", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpObject}}, |
|||
}, |
|||
handler: jpfValues, |
|||
}, |
|||
"sort": { |
|||
name: "sort", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArrayString, jpArrayNumber}}, |
|||
}, |
|||
handler: jpfSort, |
|||
}, |
|||
"sort_by": { |
|||
name: "sort_by", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArray}}, |
|||
{types: []jpType{jpExpref}}, |
|||
}, |
|||
handler: jpfSortBy, |
|||
hasExpRef: true, |
|||
}, |
|||
"join": { |
|||
name: "join", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpString}}, |
|||
{types: []jpType{jpArrayString}}, |
|||
}, |
|||
handler: jpfJoin, |
|||
}, |
|||
"reverse": { |
|||
name: "reverse", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpArray, jpString}}, |
|||
}, |
|||
handler: jpfReverse, |
|||
}, |
|||
"to_array": { |
|||
name: "to_array", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpAny}}, |
|||
}, |
|||
handler: jpfToArray, |
|||
}, |
|||
"to_string": { |
|||
name: "to_string", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpAny}}, |
|||
}, |
|||
handler: jpfToString, |
|||
}, |
|||
"to_number": { |
|||
name: "to_number", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpAny}}, |
|||
}, |
|||
handler: jpfToNumber, |
|||
}, |
|||
"not_null": { |
|||
name: "not_null", |
|||
arguments: []argSpec{ |
|||
{types: []jpType{jpAny}, variadic: true}, |
|||
}, |
|||
handler: jpfNotNull, |
|||
}, |
|||
} |
|||
return caller |
|||
} |
|||
|
|||
func (e *functionEntry) resolveArgs(arguments []interface{}) ([]interface{}, error) { |
|||
if len(e.arguments) == 0 { |
|||
return arguments, nil |
|||
} |
|||
if !e.arguments[len(e.arguments)-1].variadic { |
|||
if len(e.arguments) != len(arguments) { |
|||
return nil, errors.New("incorrect number of args") |
|||
} |
|||
for i, spec := range e.arguments { |
|||
userArg := arguments[i] |
|||
err := spec.typeCheck(userArg) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
return arguments, nil |
|||
} |
|||
if len(arguments) < len(e.arguments) { |
|||
return nil, errors.New("Invalid arity.") |
|||
} |
|||
return arguments, nil |
|||
} |
|||
|
|||
func (a *argSpec) typeCheck(arg interface{}) error { |
|||
for _, t := range a.types { |
|||
switch t { |
|||
case jpNumber: |
|||
if _, ok := arg.(float64); ok { |
|||
return nil |
|||
} |
|||
case jpString: |
|||
if _, ok := arg.(string); ok { |
|||
return nil |
|||
} |
|||
case jpArray: |
|||
if isSliceType(arg) { |
|||
return nil |
|||
} |
|||
case jpObject: |
|||
if _, ok := arg.(map[string]interface{}); ok { |
|||
return nil |
|||
} |
|||
case jpArrayNumber: |
|||
if _, ok := toArrayNum(arg); ok { |
|||
return nil |
|||
} |
|||
case jpArrayString: |
|||
if _, ok := toArrayStr(arg); ok { |
|||
return nil |
|||
} |
|||
case jpAny: |
|||
return nil |
|||
case jpExpref: |
|||
if _, ok := arg.(expRef); ok { |
|||
return nil |
|||
} |
|||
} |
|||
} |
|||
return fmt.Errorf("Invalid type for: %v, expected: %#v", arg, a.types) |
|||
} |
|||
|
|||
func (f *functionCaller) CallFunction(name string, arguments []interface{}, intr *treeInterpreter) (interface{}, error) { |
|||
entry, ok := f.functionTable[name] |
|||
if !ok { |
|||
return nil, errors.New("unknown function: " + name) |
|||
} |
|||
resolvedArgs, err := entry.resolveArgs(arguments) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if entry.hasExpRef { |
|||
var extra []interface{} |
|||
extra = append(extra, intr) |
|||
resolvedArgs = append(extra, resolvedArgs...) |
|||
} |
|||
return entry.handler(resolvedArgs) |
|||
} |
|||
|
|||
func jpfAbs(arguments []interface{}) (interface{}, error) { |
|||
num := arguments[0].(float64) |
|||
return math.Abs(num), nil |
|||
} |
|||
|
|||
func jpfLength(arguments []interface{}) (interface{}, error) { |
|||
arg := arguments[0] |
|||
if c, ok := arg.(string); ok { |
|||
return float64(utf8.RuneCountInString(c)), nil |
|||
} else if isSliceType(arg) { |
|||
v := reflect.ValueOf(arg) |
|||
return float64(v.Len()), nil |
|||
} else if c, ok := arg.(map[string]interface{}); ok { |
|||
return float64(len(c)), nil |
|||
} |
|||
return nil, errors.New("could not compute length()") |
|||
} |
|||
|
|||
func jpfStartsWith(arguments []interface{}) (interface{}, error) { |
|||
search := arguments[0].(string) |
|||
prefix := arguments[1].(string) |
|||
return strings.HasPrefix(search, prefix), nil |
|||
} |
|||
|
|||
func jpfAvg(arguments []interface{}) (interface{}, error) { |
|||
// We've already type checked the value so we can safely use
|
|||
// type assertions.
|
|||
args := arguments[0].([]interface{}) |
|||
length := float64(len(args)) |
|||
numerator := 0.0 |
|||
for _, n := range args { |
|||
numerator += n.(float64) |
|||
} |
|||
return numerator / length, nil |
|||
} |
|||
func jpfCeil(arguments []interface{}) (interface{}, error) { |
|||
val := arguments[0].(float64) |
|||
return math.Ceil(val), nil |
|||
} |
|||
func jpfContains(arguments []interface{}) (interface{}, error) { |
|||
search := arguments[0] |
|||
el := arguments[1] |
|||
if searchStr, ok := search.(string); ok { |
|||
if elStr, ok := el.(string); ok { |
|||
return strings.Index(searchStr, elStr) != -1, nil |
|||
} |
|||
return false, nil |
|||
} |
|||
// Otherwise this is a generic contains for []interface{}
|
|||
general := search.([]interface{}) |
|||
for _, item := range general { |
|||
if item == el { |
|||
return true, nil |
|||
} |
|||
} |
|||
return false, nil |
|||
} |
|||
func jpfEndsWith(arguments []interface{}) (interface{}, error) { |
|||
search := arguments[0].(string) |
|||
suffix := arguments[1].(string) |
|||
return strings.HasSuffix(search, suffix), nil |
|||
} |
|||
func jpfFloor(arguments []interface{}) (interface{}, error) { |
|||
val := arguments[0].(float64) |
|||
return math.Floor(val), nil |
|||
} |
|||
func jpfMap(arguments []interface{}) (interface{}, error) { |
|||
intr := arguments[0].(*treeInterpreter) |
|||
exp := arguments[1].(expRef) |
|||
node := exp.ref |
|||
arr := arguments[2].([]interface{}) |
|||
mapped := make([]interface{}, 0, len(arr)) |
|||
for _, value := range arr { |
|||
current, err := intr.Execute(node, value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
mapped = append(mapped, current) |
|||
} |
|||
return mapped, nil |
|||
} |
|||
func jpfMax(arguments []interface{}) (interface{}, error) { |
|||
if items, ok := toArrayNum(arguments[0]); ok { |
|||
if len(items) == 0 { |
|||
return nil, nil |
|||
} |
|||
if len(items) == 1 { |
|||
return items[0], nil |
|||
} |
|||
best := items[0] |
|||
for _, item := range items[1:] { |
|||
if item > best { |
|||
best = item |
|||
} |
|||
} |
|||
return best, nil |
|||
} |
|||
// Otherwise we're dealing with a max() of strings.
|
|||
items, _ := toArrayStr(arguments[0]) |
|||
if len(items) == 0 { |
|||
return nil, nil |
|||
} |
|||
if len(items) == 1 { |
|||
return items[0], nil |
|||
} |
|||
best := items[0] |
|||
for _, item := range items[1:] { |
|||
if item > best { |
|||
best = item |
|||
} |
|||
} |
|||
return best, nil |
|||
} |
|||
func jpfMerge(arguments []interface{}) (interface{}, error) { |
|||
final := make(map[string]interface{}) |
|||
for _, m := range arguments { |
|||
mapped := m.(map[string]interface{}) |
|||
for key, value := range mapped { |
|||
final[key] = value |
|||
} |
|||
} |
|||
return final, nil |
|||
} |
|||
func jpfMaxBy(arguments []interface{}) (interface{}, error) { |
|||
intr := arguments[0].(*treeInterpreter) |
|||
arr := arguments[1].([]interface{}) |
|||
exp := arguments[2].(expRef) |
|||
node := exp.ref |
|||
if len(arr) == 0 { |
|||
return nil, nil |
|||
} else if len(arr) == 1 { |
|||
return arr[0], nil |
|||
} |
|||
start, err := intr.Execute(node, arr[0]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
switch t := start.(type) { |
|||
case float64: |
|||
bestVal := t |
|||
bestItem := arr[0] |
|||
for _, item := range arr[1:] { |
|||
result, err := intr.Execute(node, item) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
current, ok := result.(float64) |
|||
if !ok { |
|||
return nil, errors.New("invalid type, must be number") |
|||
} |
|||
if current > bestVal { |
|||
bestVal = current |
|||
bestItem = item |
|||
} |
|||
} |
|||
return bestItem, nil |
|||
case string: |
|||
bestVal := t |
|||
bestItem := arr[0] |
|||
for _, item := range arr[1:] { |
|||
result, err := intr.Execute(node, item) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
current, ok := result.(string) |
|||
if !ok { |
|||
return nil, errors.New("invalid type, must be string") |
|||
} |
|||
if current > bestVal { |
|||
bestVal = current |
|||
bestItem = item |
|||
} |
|||
} |
|||
return bestItem, nil |
|||
default: |
|||
return nil, errors.New("invalid type, must be number of string") |
|||
} |
|||
} |
|||
func jpfSum(arguments []interface{}) (interface{}, error) { |
|||
items, _ := toArrayNum(arguments[0]) |
|||
sum := 0.0 |
|||
for _, item := range items { |
|||
sum += item |
|||
} |
|||
return sum, nil |
|||
} |
|||
|
|||
func jpfMin(arguments []interface{}) (interface{}, error) { |
|||
if items, ok := toArrayNum(arguments[0]); ok { |
|||
if len(items) == 0 { |
|||
return nil, nil |
|||
} |
|||
if len(items) == 1 { |
|||
return items[0], nil |
|||
} |
|||
best := items[0] |
|||
for _, item := range items[1:] { |
|||
if item < best { |
|||
best = item |
|||
} |
|||
} |
|||
return best, nil |
|||
} |
|||
items, _ := toArrayStr(arguments[0]) |
|||
if len(items) == 0 { |
|||
return nil, nil |
|||
} |
|||
if len(items) == 1 { |
|||
return items[0], nil |
|||
} |
|||
best := items[0] |
|||
for _, item := range items[1:] { |
|||
if item < best { |
|||
best = item |
|||
} |
|||
} |
|||
return best, nil |
|||
} |
|||
|
|||
func jpfMinBy(arguments []interface{}) (interface{}, error) { |
|||
intr := arguments[0].(*treeInterpreter) |
|||
arr := arguments[1].([]interface{}) |
|||
exp := arguments[2].(expRef) |
|||
node := exp.ref |
|||
if len(arr) == 0 { |
|||
return nil, nil |
|||
} else if len(arr) == 1 { |
|||
return arr[0], nil |
|||
} |
|||
start, err := intr.Execute(node, arr[0]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if t, ok := start.(float64); ok { |
|||
bestVal := t |
|||
bestItem := arr[0] |
|||
for _, item := range arr[1:] { |
|||
result, err := intr.Execute(node, item) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
current, ok := result.(float64) |
|||
if !ok { |
|||
return nil, errors.New("invalid type, must be number") |
|||
} |
|||
if current < bestVal { |
|||
bestVal = current |
|||
bestItem = item |
|||
} |
|||
} |
|||
return bestItem, nil |
|||
} else if t, ok := start.(string); ok { |
|||
bestVal := t |
|||
bestItem := arr[0] |
|||
for _, item := range arr[1:] { |
|||
result, err := intr.Execute(node, item) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
current, ok := result.(string) |
|||
if !ok { |
|||
return nil, errors.New("invalid type, must be string") |
|||
} |
|||
if current < bestVal { |
|||
bestVal = current |
|||
bestItem = item |
|||
} |
|||
} |
|||
return bestItem, nil |
|||
} else { |
|||
return nil, errors.New("invalid type, must be number of string") |
|||
} |
|||
} |
|||
func jpfType(arguments []interface{}) (interface{}, error) { |
|||
arg := arguments[0] |
|||
if _, ok := arg.(float64); ok { |
|||
return "number", nil |
|||
} |
|||
if _, ok := arg.(string); ok { |
|||
return "string", nil |
|||
} |
|||
if _, ok := arg.([]interface{}); ok { |
|||
return "array", nil |
|||
} |
|||
if _, ok := arg.(map[string]interface{}); ok { |
|||
return "object", nil |
|||
} |
|||
if arg == nil { |
|||
return "null", nil |
|||
} |
|||
if arg == true || arg == false { |
|||
return "boolean", nil |
|||
} |
|||
return nil, errors.New("unknown type") |
|||
} |
|||
func jpfKeys(arguments []interface{}) (interface{}, error) { |
|||
arg := arguments[0].(map[string]interface{}) |
|||
collected := make([]interface{}, 0, len(arg)) |
|||
for key := range arg { |
|||
collected = append(collected, key) |
|||
} |
|||
return collected, nil |
|||
} |
|||
func jpfValues(arguments []interface{}) (interface{}, error) { |
|||
arg := arguments[0].(map[string]interface{}) |
|||
collected := make([]interface{}, 0, len(arg)) |
|||
for _, value := range arg { |
|||
collected = append(collected, value) |
|||
} |
|||
return collected, nil |
|||
} |
|||
func jpfSort(arguments []interface{}) (interface{}, error) { |
|||
if items, ok := toArrayNum(arguments[0]); ok { |
|||
d := sort.Float64Slice(items) |
|||
sort.Stable(d) |
|||
final := make([]interface{}, len(d)) |
|||
for i, val := range d { |
|||
final[i] = val |
|||
} |
|||
return final, nil |
|||
} |
|||
// Otherwise we're dealing with sort()'ing strings.
|
|||
items, _ := toArrayStr(arguments[0]) |
|||
d := sort.StringSlice(items) |
|||
sort.Stable(d) |
|||
final := make([]interface{}, len(d)) |
|||
for i, val := range d { |
|||
final[i] = val |
|||
} |
|||
return final, nil |
|||
} |
|||
func jpfSortBy(arguments []interface{}) (interface{}, error) { |
|||
intr := arguments[0].(*treeInterpreter) |
|||
arr := arguments[1].([]interface{}) |
|||
exp := arguments[2].(expRef) |
|||
node := exp.ref |
|||
if len(arr) == 0 { |
|||
return arr, nil |
|||
} else if len(arr) == 1 { |
|||
return arr, nil |
|||
} |
|||
start, err := intr.Execute(node, arr[0]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if _, ok := start.(float64); ok { |
|||
sortable := &byExprFloat{intr, node, arr, false} |
|||
sort.Stable(sortable) |
|||
if sortable.hasError { |
|||
return nil, errors.New("error in sort_by comparison") |
|||
} |
|||
return arr, nil |
|||
} else if _, ok := start.(string); ok { |
|||
sortable := &byExprString{intr, node, arr, false} |
|||
sort.Stable(sortable) |
|||
if sortable.hasError { |
|||
return nil, errors.New("error in sort_by comparison") |
|||
} |
|||
return arr, nil |
|||
} else { |
|||
return nil, errors.New("invalid type, must be number of string") |
|||
} |
|||
} |
|||
func jpfJoin(arguments []interface{}) (interface{}, error) { |
|||
sep := arguments[0].(string) |
|||
// We can't just do arguments[1].([]string), we have to
|
|||
// manually convert each item to a string.
|
|||
arrayStr := []string{} |
|||
for _, item := range arguments[1].([]interface{}) { |
|||
arrayStr = append(arrayStr, item.(string)) |
|||
} |
|||
return strings.Join(arrayStr, sep), nil |
|||
} |
|||
func jpfReverse(arguments []interface{}) (interface{}, error) { |
|||
if s, ok := arguments[0].(string); ok { |
|||
r := []rune(s) |
|||
for i, j := 0, len(r)-1; i < len(r)/2; i, j = i+1, j-1 { |
|||
r[i], r[j] = r[j], r[i] |
|||
} |
|||
return string(r), nil |
|||
} |
|||
items := arguments[0].([]interface{}) |
|||
length := len(items) |
|||
reversed := make([]interface{}, length) |
|||
for i, item := range items { |
|||
reversed[length-(i+1)] = item |
|||
} |
|||
return reversed, nil |
|||
} |
|||
func jpfToArray(arguments []interface{}) (interface{}, error) { |
|||
if _, ok := arguments[0].([]interface{}); ok { |
|||
return arguments[0], nil |
|||
} |
|||
return arguments[:1:1], nil |
|||
} |
|||
func jpfToString(arguments []interface{}) (interface{}, error) { |
|||
if v, ok := arguments[0].(string); ok { |
|||
return v, nil |
|||
} |
|||
result, err := json.Marshal(arguments[0]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return string(result), nil |
|||
} |
|||
func jpfToNumber(arguments []interface{}) (interface{}, error) { |
|||
arg := arguments[0] |
|||
if v, ok := arg.(float64); ok { |
|||
return v, nil |
|||
} |
|||
if v, ok := arg.(string); ok { |
|||
conv, err := strconv.ParseFloat(v, 64) |
|||
if err != nil { |
|||
return nil, nil |
|||
} |
|||
return conv, nil |
|||
} |
|||
if _, ok := arg.([]interface{}); ok { |
|||
return nil, nil |
|||
} |
|||
if _, ok := arg.(map[string]interface{}); ok { |
|||
return nil, nil |
|||
} |
|||
if arg == nil { |
|||
return nil, nil |
|||
} |
|||
if arg == true || arg == false { |
|||
return nil, nil |
|||
} |
|||
return nil, errors.New("unknown type") |
|||
} |
|||
func jpfNotNull(arguments []interface{}) (interface{}, error) { |
|||
for _, arg := range arguments { |
|||
if arg != nil { |
|||
return arg, nil |
|||
} |
|||
} |
|||
return nil, nil |
|||
} |
@ -0,0 +1,418 @@ |
|||
package jmespath |
|||
|
|||
import ( |
|||
"errors" |
|||
"reflect" |
|||
"unicode" |
|||
"unicode/utf8" |
|||
) |
|||
|
|||
/* This is a tree based interpreter. It walks the AST and directly |
|||
interprets the AST to search through a JSON document. |
|||
*/ |
|||
|
|||
type treeInterpreter struct { |
|||
fCall *functionCaller |
|||
} |
|||
|
|||
func newInterpreter() *treeInterpreter { |
|||
interpreter := treeInterpreter{} |
|||
interpreter.fCall = newFunctionCaller() |
|||
return &interpreter |
|||
} |
|||
|
|||
type expRef struct { |
|||
ref ASTNode |
|||
} |
|||
|
|||
// Execute takes an ASTNode and input data and interprets the AST directly.
|
|||
// It will produce the result of applying the JMESPath expression associated
|
|||
// with the ASTNode to the input data "value".
|
|||
func (intr *treeInterpreter) Execute(node ASTNode, value interface{}) (interface{}, error) { |
|||
switch node.nodeType { |
|||
case ASTComparator: |
|||
left, err := intr.Execute(node.children[0], value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
right, err := intr.Execute(node.children[1], value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
switch node.value { |
|||
case tEQ: |
|||
return objsEqual(left, right), nil |
|||
case tNE: |
|||
return !objsEqual(left, right), nil |
|||
} |
|||
leftNum, ok := left.(float64) |
|||
if !ok { |
|||
return nil, nil |
|||
} |
|||
rightNum, ok := right.(float64) |
|||
if !ok { |
|||
return nil, nil |
|||
} |
|||
switch node.value { |
|||
case tGT: |
|||
return leftNum > rightNum, nil |
|||
case tGTE: |
|||
return leftNum >= rightNum, nil |
|||
case tLT: |
|||
return leftNum < rightNum, nil |
|||
case tLTE: |
|||
return leftNum <= rightNum, nil |
|||
} |
|||
case ASTExpRef: |
|||
return expRef{ref: node.children[0]}, nil |
|||
case ASTFunctionExpression: |
|||
resolvedArgs := []interface{}{} |
|||
for _, arg := range node.children { |
|||
current, err := intr.Execute(arg, value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
resolvedArgs = append(resolvedArgs, current) |
|||
} |
|||
return intr.fCall.CallFunction(node.value.(string), resolvedArgs, intr) |
|||
case ASTField: |
|||
if m, ok := value.(map[string]interface{}); ok { |
|||
key := node.value.(string) |
|||
return m[key], nil |
|||
} |
|||
return intr.fieldFromStruct(node.value.(string), value) |
|||
case ASTFilterProjection: |
|||
left, err := intr.Execute(node.children[0], value) |
|||
if err != nil { |
|||
return nil, nil |
|||
} |
|||
sliceType, ok := left.([]interface{}) |
|||
if !ok { |
|||
if isSliceType(left) { |
|||
return intr.filterProjectionWithReflection(node, left) |
|||
} |
|||
return nil, nil |
|||
} |
|||
compareNode := node.children[2] |
|||
collected := []interface{}{} |
|||
for _, element := range sliceType { |
|||
result, err := intr.Execute(compareNode, element) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if !isFalse(result) { |
|||
current, err := intr.Execute(node.children[1], element) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if current != nil { |
|||
collected = append(collected, current) |
|||
} |
|||
} |
|||
} |
|||
return collected, nil |
|||
case ASTFlatten: |
|||
left, err := intr.Execute(node.children[0], value) |
|||
if err != nil { |
|||
return nil, nil |
|||
} |
|||
sliceType, ok := left.([]interface{}) |
|||
if !ok { |
|||
// If we can't type convert to []interface{}, there's
|
|||
// a chance this could still work via reflection if we're
|
|||
// dealing with user provided types.
|
|||
if isSliceType(left) { |
|||
return intr.flattenWithReflection(left) |
|||
} |
|||
return nil, nil |
|||
} |
|||
flattened := []interface{}{} |
|||
for _, element := range sliceType { |
|||
if elementSlice, ok := element.([]interface{}); ok { |
|||
flattened = append(flattened, elementSlice...) |
|||
} else if isSliceType(element) { |
|||
reflectFlat := []interface{}{} |
|||
v := reflect.ValueOf(element) |
|||
for i := 0; i < v.Len(); i++ { |
|||
reflectFlat = append(reflectFlat, v.Index(i).Interface()) |
|||
} |
|||
flattened = append(flattened, reflectFlat...) |
|||
} else { |
|||
flattened = append(flattened, element) |
|||
} |
|||
} |
|||
return flattened, nil |
|||
case ASTIdentity, ASTCurrentNode: |
|||
return value, nil |
|||
case ASTIndex: |
|||
if sliceType, ok := value.([]interface{}); ok { |
|||
index := node.value.(int) |
|||
if index < 0 { |
|||
index += len(sliceType) |
|||
} |
|||
if index < len(sliceType) && index >= 0 { |
|||
return sliceType[index], nil |
|||
} |
|||
return nil, nil |
|||
} |
|||
// Otherwise try via reflection.
|
|||
rv := reflect.ValueOf(value) |
|||
if rv.Kind() == reflect.Slice { |
|||
index := node.value.(int) |
|||
if index < 0 { |
|||
index += rv.Len() |
|||
} |
|||
if index < rv.Len() && index >= 0 { |
|||
v := rv.Index(index) |
|||
return v.Interface(), nil |
|||
} |
|||
} |
|||
return nil, nil |
|||
case ASTKeyValPair: |
|||
return intr.Execute(node.children[0], value) |
|||
case ASTLiteral: |
|||
return node.value, nil |
|||
case ASTMultiSelectHash: |
|||
if value == nil { |
|||
return nil, nil |
|||
} |
|||
collected := make(map[string]interface{}) |
|||
for _, child := range node.children { |
|||
current, err := intr.Execute(child, value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
key := child.value.(string) |
|||
collected[key] = current |
|||
} |
|||
return collected, nil |
|||
case ASTMultiSelectList: |
|||
if value == nil { |
|||
return nil, nil |
|||
} |
|||
collected := []interface{}{} |
|||
for _, child := range node.children { |
|||
current, err := intr.Execute(child, value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
collected = append(collected, current) |
|||
} |
|||
return collected, nil |
|||
case ASTOrExpression: |
|||
matched, err := intr.Execute(node.children[0], value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if isFalse(matched) { |
|||
matched, err = intr.Execute(node.children[1], value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
return matched, nil |
|||
case ASTAndExpression: |
|||
matched, err := intr.Execute(node.children[0], value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if isFalse(matched) { |
|||
return matched, nil |
|||
} |
|||
return intr.Execute(node.children[1], value) |
|||
case ASTNotExpression: |
|||
matched, err := intr.Execute(node.children[0], value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if isFalse(matched) { |
|||
return true, nil |
|||
} |
|||
return false, nil |
|||
case ASTPipe: |
|||
result := value |
|||
var err error |
|||
for _, child := range node.children { |
|||
result, err = intr.Execute(child, result) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
return result, nil |
|||
case ASTProjection: |
|||
left, err := intr.Execute(node.children[0], value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
sliceType, ok := left.([]interface{}) |
|||
if !ok { |
|||
if isSliceType(left) { |
|||
return intr.projectWithReflection(node, left) |
|||
} |
|||
return nil, nil |
|||
} |
|||
collected := []interface{}{} |
|||
var current interface{} |
|||
for _, element := range sliceType { |
|||
current, err = intr.Execute(node.children[1], element) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if current != nil { |
|||
collected = append(collected, current) |
|||
} |
|||
} |
|||
return collected, nil |
|||
case ASTSubexpression, ASTIndexExpression: |
|||
left, err := intr.Execute(node.children[0], value) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return intr.Execute(node.children[1], left) |
|||
case ASTSlice: |
|||
sliceType, ok := value.([]interface{}) |
|||
if !ok { |
|||
if isSliceType(value) { |
|||
return intr.sliceWithReflection(node, value) |
|||
} |
|||
return nil, nil |
|||
} |
|||
parts := node.value.([]*int) |
|||
sliceParams := make([]sliceParam, 3) |
|||
for i, part := range parts { |
|||
if part != nil { |
|||
sliceParams[i].Specified = true |
|||
sliceParams[i].N = *part |
|||
} |
|||
} |
|||
return slice(sliceType, sliceParams) |
|||
case ASTValueProjection: |
|||
left, err := intr.Execute(node.children[0], value) |
|||
if err != nil { |
|||
return nil, nil |
|||
} |
|||
mapType, ok := left.(map[string]interface{}) |
|||
if !ok { |
|||
return nil, nil |
|||
} |
|||
values := make([]interface{}, len(mapType)) |
|||
for _, value := range mapType { |
|||
values = append(values, value) |
|||
} |
|||
collected := []interface{}{} |
|||
for _, element := range values { |
|||
current, err := intr.Execute(node.children[1], element) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if current != nil { |
|||
collected = append(collected, current) |
|||
} |
|||
} |
|||
return collected, nil |
|||
} |
|||
return nil, errors.New("Unknown AST node: " + node.nodeType.String()) |
|||
} |
|||
|
|||
func (intr *treeInterpreter) fieldFromStruct(key string, value interface{}) (interface{}, error) { |
|||
rv := reflect.ValueOf(value) |
|||
first, n := utf8.DecodeRuneInString(key) |
|||
fieldName := string(unicode.ToUpper(first)) + key[n:] |
|||
if rv.Kind() == reflect.Struct { |
|||
v := rv.FieldByName(fieldName) |
|||
if !v.IsValid() { |
|||
return nil, nil |
|||
} |
|||
return v.Interface(), nil |
|||
} else if rv.Kind() == reflect.Ptr { |
|||
// Handle multiple levels of indirection?
|
|||
if rv.IsNil() { |
|||
return nil, nil |
|||
} |
|||
rv = rv.Elem() |
|||
v := rv.FieldByName(fieldName) |
|||
if !v.IsValid() { |
|||
return nil, nil |
|||
} |
|||
return v.Interface(), nil |
|||
} |
|||
return nil, nil |
|||
} |
|||
|
|||
func (intr *treeInterpreter) flattenWithReflection(value interface{}) (interface{}, error) { |
|||
v := reflect.ValueOf(value) |
|||
flattened := []interface{}{} |
|||
for i := 0; i < v.Len(); i++ { |
|||
element := v.Index(i).Interface() |
|||
if reflect.TypeOf(element).Kind() == reflect.Slice { |
|||
// Then insert the contents of the element
|
|||
// slice into the flattened slice,
|
|||
// i.e flattened = append(flattened, mySlice...)
|
|||
elementV := reflect.ValueOf(element) |
|||
for j := 0; j < elementV.Len(); j++ { |
|||
flattened = append( |
|||
flattened, elementV.Index(j).Interface()) |
|||
} |
|||
} else { |
|||
flattened = append(flattened, element) |
|||
} |
|||
} |
|||
return flattened, nil |
|||
} |
|||
|
|||
func (intr *treeInterpreter) sliceWithReflection(node ASTNode, value interface{}) (interface{}, error) { |
|||
v := reflect.ValueOf(value) |
|||
parts := node.value.([]*int) |
|||
sliceParams := make([]sliceParam, 3) |
|||
for i, part := range parts { |
|||
if part != nil { |
|||
sliceParams[i].Specified = true |
|||
sliceParams[i].N = *part |
|||
} |
|||
} |
|||
final := []interface{}{} |
|||
for i := 0; i < v.Len(); i++ { |
|||
element := v.Index(i).Interface() |
|||
final = append(final, element) |
|||
} |
|||
return slice(final, sliceParams) |
|||
} |
|||
|
|||
func (intr *treeInterpreter) filterProjectionWithReflection(node ASTNode, value interface{}) (interface{}, error) { |
|||
compareNode := node.children[2] |
|||
collected := []interface{}{} |
|||
v := reflect.ValueOf(value) |
|||
for i := 0; i < v.Len(); i++ { |
|||
element := v.Index(i).Interface() |
|||
result, err := intr.Execute(compareNode, element) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if !isFalse(result) { |
|||
current, err := intr.Execute(node.children[1], element) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if current != nil { |
|||
collected = append(collected, current) |
|||
} |
|||
} |
|||
} |
|||
return collected, nil |
|||
} |
|||
|
|||
func (intr *treeInterpreter) projectWithReflection(node ASTNode, value interface{}) (interface{}, error) { |
|||
collected := []interface{}{} |
|||
v := reflect.ValueOf(value) |
|||
for i := 0; i < v.Len(); i++ { |
|||
element := v.Index(i).Interface() |
|||
result, err := intr.Execute(node.children[1], element) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if result != nil { |
|||
collected = append(collected, result) |
|||
} |
|||
} |
|||
return collected, nil |
|||
} |
Some files were not shown because too many files changed in this diff
Loading…
Reference in new issue