77package unified
88
99import (
10+ "bytes"
1011 "context"
1112 "encoding/base64"
1213 "fmt"
@@ -41,11 +42,18 @@ var (
4142 "createUser" , "updateUser" , "copydbgetnonce" , "copydbsaslstart" , "copydb" ,
4243 }
4344
44- awsAccessKeyID = os .Getenv ("FLE_AWS_KEY" )
45- awsSecretAccessKey = os .Getenv ("FLE_AWS_SECRET" )
46- azureTenantID = os .Getenv ("FLE_AZURE_TENANTID" )
47- azureClientID = os .Getenv ("FLE_AZURE_CLIENTID" )
48- azureClientSecret = os .Getenv ("FLE_AZURE_CLIENTSECRET" )
45+ awsAccessKeyID = os .Getenv ("FLE_AWS_KEY" )
46+ awsSecretAccessKey = os .Getenv ("FLE_AWS_SECRET" )
47+ awsTempAccessKeyID = os .Getenv ("CSFLE_AWS_TEMP_ACCESS_KEY_ID" )
48+ awsTempSecretAccessKey = os .Getenv ("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY" )
49+ awsTempSessionToken = os .Getenv ("CSFLE_AWS_TEMP_SESSION_TOKEN" )
50+ azureTenantID = os .Getenv ("FLE_AZURE_TENANTID" )
51+ azureClientID = os .Getenv ("FLE_AZURE_CLIENTID" )
52+ azureClientSecret = os .Getenv ("FLE_AZURE_CLIENTSECRET" )
53+ gcpEmail = os .Getenv ("FLE_GCP_EMAIL" )
54+ gcpPrivateKey = os .Getenv ("FLE_GCP_PRIVATEKEY" )
55+
56+ placeholderDoc = bsoncore .NewDocumentBuilder ().AppendInt32 ("$$placeholder" , 1 ).Build ()
4957)
5058
5159// clientEntity is a wrapper for a mongo.Client object that also holds additional information required during test
@@ -287,33 +295,112 @@ func createAutoEncryptionOptions(opts bson.Raw) (*options.AutoEncryptionOptions,
287295 if err != nil {
288296 return nil , err
289297 }
298+ retrieveProviderData := func (doc bson.Raw , key , defaultVal string ) (any , error ) {
299+ e := doc .Lookup (key )
300+ if e .IsZero () {
301+ return nil , nil
302+ }
303+ switch e .Type {
304+ case bson .TypeString :
305+ return e .StringValue (), nil
306+ case bson .TypeEmbeddedDocument :
307+ if bytes .Equal (e .Document (), placeholderDoc ) {
308+ return defaultVal , nil
309+ }
310+ }
311+ return nil , fmt .Errorf ("unexpected %s in kms provider: %v" , key , e )
312+ }
290313 for _ , elem := range elems {
291- provider := elem .Key ()
292- providerOpt := elem .Value ()
293- switch provider {
314+ provider := make (map [string ]any )
315+ providerT := elem .Key ()
316+ providerOpt := elem .Value ().Document ()
317+ switch providerT {
294318 case "aws" :
295- providers ["aws" ] = map [string ]any {
296- "accessKeyId" : awsAccessKeyID ,
297- "secretAccessKey" : awsSecretAccessKey ,
319+ accessKeyID := awsAccessKeyID
320+ secretAccessKey := awsSecretAccessKey
321+
322+ // replace with temporary access, if sessionToken placeholder exists
323+ v , err := retrieveProviderData (providerOpt , "sessionToken" , "$$placeholder" )
324+ if err != nil {
325+ return nil , err
326+ }
327+ if v == "$$placeholder" {
328+ provider ["sessionToken" ] = awsTempSessionToken
329+ accessKeyID = awsTempAccessKeyID
330+ secretAccessKey = awsTempSecretAccessKey
331+ } else if v != nil {
332+ provider ["sessionToken" ] = v
333+ }
334+
335+ for _ , e := range []struct {
336+ key string
337+ defaultVal string
338+ }{
339+ {"accessKeyId" , accessKeyID },
340+ {"secretAccessKey" , secretAccessKey },
341+ } {
342+ v , err = retrieveProviderData (providerOpt , e .key , e .defaultVal )
343+ if err != nil {
344+ return nil , err
345+ }
346+ if v != nil {
347+ provider [e .key ] = v
348+ }
298349 }
299350 case "azure" :
300- providers ["azure" ] = map [string ]any {
301- "tenantId" : azureTenantID ,
302- "clientId" : azureClientID ,
303- "clientSecret" : azureClientSecret ,
351+ for _ , e := range []struct {
352+ key string
353+ defaultVal string
354+ }{
355+ {"tenantId" , azureTenantID },
356+ {"clientId" , azureClientID },
357+ {"clientSecret" , azureClientSecret },
358+ } {
359+ v , err := retrieveProviderData (providerOpt , e .key , e .defaultVal )
360+ if err != nil {
361+ return nil , err
362+ }
363+ if v != nil {
364+ provider [e .key ] = v
365+ }
304366 }
305- case "local" :
306- str := providerOpt .Document ().Lookup ("key" ).StringValue ()
307- key , err := base64 .StdEncoding .DecodeString (str )
367+ case "gcp" :
368+ for _ , e := range []struct {
369+ key string
370+ defaultVal string
371+ }{
372+ {"email" , gcpEmail },
373+ {"privateKey" , gcpPrivateKey },
374+ } {
375+ v , err := retrieveProviderData (providerOpt , e .key , e .defaultVal )
376+ if err != nil {
377+ return nil , err
378+ }
379+ if v != nil {
380+ provider [e .key ] = v
381+ }
382+ }
383+ case "kmip" :
384+ v , err := retrieveProviderData (providerOpt , "endpoint" , "localhost:5698" )
308385 if err != nil {
309386 return nil , err
310387 }
311- providers [ "local" ] = map [ string ] any {
312- "key" : key ,
388+ if v != nil {
389+ provider [ "endpoint" ] = v
313390 }
391+ case "local" , "local:name2" :
392+ str := providerOpt .Lookup ("key" ).StringValue ()
393+ key , err := base64 .StdEncoding .DecodeString (str )
394+ if err != nil {
395+ return nil , err
396+ }
397+ provider ["key" ] = key
314398 default :
315399 return nil , fmt .Errorf ("unrecognized KMS provider: %v" , provider )
316400 }
401+ if len (provider ) > 0 {
402+ providers [providerT ] = provider
403+ }
317404 }
318405 aeo .SetKmsProviders (providers )
319406 case "schemaMap" :
@@ -328,6 +415,8 @@ func createAutoEncryptionOptions(opts bson.Raw) (*options.AutoEncryptionOptions,
328415 aeo .SetKeyVaultNamespace (opt .StringValue ())
329416 case "bypassQueryAnalysis" :
330417 aeo .SetBypassQueryAnalysis (opt .Boolean ())
418+ case "bypassAutoEncryption" :
419+ aeo .SetBypassAutoEncryption (opt .Boolean ())
331420 default :
332421 return nil , fmt .Errorf ("unrecognized option: %v" , name )
333422 }
0 commit comments