# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import abc
import logging
import warnings
import xml.etree.ElementTree as ET
from collections import OrderedDict
import re
import csv
import tempfile
import luigi
from luigi import Task
logger = logging.getLogger('luigi-interface')
try:
import requests
except ImportError:
logger.warning("This module requires the python package 'requests'.")
try:
from urlparse import urlsplit
except ImportError:
from urllib.parse import urlsplit
def get_soql_fields(soql):
"""
Gets queried columns names.
"""
    soql_fields = re.search('(?<=select)(.*)(?=from)', soql, re.IGNORECASE | re.DOTALL)  # get the field list between SELECT and FROM
    soql_fields = re.sub(' ', '', soql_fields.group())  # remove extra spaces
    soql_fields = re.sub('\t', '', soql_fields)  # remove tabs
    fields = re.split(',|\n|\r', soql_fields)  # split on commas and newlines
fields = [field for field in fields if field != ''] # remove empty strings
return fields
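# Illustrative note (not part of the original module): for a simple query the
# helper returns the selected column names in order, e.g.
#   get_soql_fields("SELECT Id, Name, Owner.Name FROM Contact")
#   -> ['Id', 'Name', 'Owner.Name']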
def ensure_utf(value):
return value.encode("utf-8") if isinstance(value, unicode) else value
def parse_results(fields, data):
    """
    Traverses the ordered dictionary of records, calling _traverse_results() to recursively read nested objects in the data.
    """
master = []
for record in data['records']: # for each 'record' in response
        row = [None] * len(fields)  # create a list of Nones, one per queried column
for obj, value in record.iteritems(): # for each obj in record
if not isinstance(value, (dict, list, tuple)): # if not data structure
if obj in fields:
row[fields.index(obj)] = ensure_utf(value)
elif isinstance(value, dict) and obj != 'attributes': # traverse down into object
path = obj
_traverse_results(value, fields, row, path)
master.append(row)
return master
def _traverse_results(value, fields, row, path):
"""
Helper method for parse_results().
Traverses through ordered dict and recursively calls itself when encountering a dictionary
"""
for f, v in value.iteritems(): # for each item in obj
field_name = '{path}.{name}'.format(path=path, name=f) if path else f
if not isinstance(v, (dict, list, tuple)): # if not data structure
if field_name in fields:
row[fields.index(field_name)] = ensure_utf(v)
elif isinstance(v, dict) and f != 'attributes': # it is a dict
_traverse_results(v, fields, row, field_name)
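# Illustrative note (hypothetical values, not from the original module): with
# fields == ['Id', 'Owner.Name'], a record such as
#   {'Id': '003xx', 'attributes': {...}, 'Owner': {'attributes': {...}, 'Name': 'Jane'}}
# yields the row ['003xx', 'Jane']; nested objects are flattened into
# dot-separated column names and 'attributes' entries are skipped.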
class salesforce(luigi.Config):
"""
Config system to get config vars from 'salesforce' section in configuration file.
Did not include sandbox_name here, as the user may have multiple sandboxes.
"""
username = luigi.Parameter(default='')
password = luigi.Parameter(default='')
security_token = luigi.Parameter(default='')
# sandbox token
sb_security_token = luigi.Parameter(default='')
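# A hypothetical example of the corresponding section in the luigi configuration
# file (all values are placeholders). The optional 'local-tmp-dir' key is read
# separately by SalesforceAPI.query_all() below.
#
#   [salesforce]
#   username=user@example.com
#   password=secret
#   security_token=PRODUCTION_TOKEN
#   sb_security_token=SANDBOX_TOKEN
#   local-tmp-dir=/tmp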
class QuerySalesforce(Task):
@abc.abstractproperty
def object_name(self):
"""
Override to return the SF object we are querying.
        Must have the SF "__c" suffix if it is a custom object.
"""
return None
@property
def use_sandbox(self):
"""
Override to specify use of SF sandbox.
True iff we should be uploading to a sandbox environment instead of the production organization.
"""
return False
@property
def sandbox_name(self):
"""Override to specify the sandbox name if it is intended to be used."""
return None
@abc.abstractproperty
def soql(self):
"""Override to return the raw string SOQL or the path to it."""
return None
@property
def is_soql_file(self):
"""Override to True if soql property is a file path."""
return False
@property
def content_type(self):
"""
Override to use a different content type. Salesforce allows XML, CSV, ZIP_CSV, or ZIP_XML. Defaults to CSV.
"""
return "CSV"
    def run(self):
if self.use_sandbox and not self.sandbox_name:
raise Exception("Parameter sf_sandbox_name must be provided when uploading to a Salesforce Sandbox")
sf = SalesforceAPI(salesforce().username,
salesforce().password,
salesforce().security_token,
salesforce().sb_security_token,
self.sandbox_name)
job_id = sf.create_operation_job('query', self.object_name, content_type=self.content_type)
logger.info("Started query job %s in salesforce for object %s" % (job_id, self.object_name))
batch_id = ''
msg = ''
try:
if self.is_soql_file:
with open(self.soql, 'r') as infile:
self.soql = infile.read()
batch_id = sf.create_batch(job_id, self.soql, self.content_type)
logger.info("Creating new batch %s to query: %s for job: %s." % (batch_id, self.object_name, job_id))
status = sf.block_on_batch(job_id, batch_id)
if status['state'].lower() == 'failed':
msg = "Batch failed with message: %s" % status['state_message']
logger.error(msg)
# don't raise exception if it's b/c of an included relationship
# normal query will execute (with relationship) after bulk job is closed
if 'foreign key relationships not supported' not in status['state_message'].lower():
raise Exception(msg)
else:
result_ids = sf.get_batch_result_ids(job_id, batch_id)
# If there's only one result, just download it, otherwise we need to merge the resulting downloads
if len(result_ids) == 1:
data = sf.get_batch_result(job_id, batch_id, result_ids[0])
with open(self.output().path, 'wb') as outfile:
outfile.write(data)
else:
# Download each file to disk, and then merge into one.
# Preferring to do it this way so as to minimize memory consumption.
for i, result_id in enumerate(result_ids):
logger.info("Downloading batch result %s for batch: %s and job: %s" % (result_id, batch_id, job_id))
with open("%s.%d" % (self.output().path, i), 'wb') as outfile:
outfile.write(sf.get_batch_result(job_id, batch_id, result_id))
logger.info("Merging results of batch %s" % batch_id)
self.merge_batch_results(result_ids)
finally:
logger.info("Closing job %s" % job_id)
sf.close_job(job_id)
if 'state_message' in status and 'foreign key relationships not supported' in status['state_message'].lower():
logger.info("Retrying with REST API query")
data_file = sf.query_all(self.soql)
reader = csv.reader(data_file)
with open(self.output().path, 'wb') as outfile:
writer = csv.writer(outfile, dialect='excel')
for row in reader:
writer.writerow(row)
    def merge_batch_results(self, result_ids):
"""
Merges the resulting files of a multi-result batch bulk query.
"""
outfile = open(self.output().path, 'w')
if self.content_type.lower() == 'csv':
for i, result_id in enumerate(result_ids):
with open("%s.%d" % (self.output().path, i), 'r') as f:
header = f.readline()
if i == 0:
outfile.write(header)
for line in f:
outfile.write(line)
else:
raise Exception("Batch result merging not implemented for %s" % self.content_type)
outfile.close()
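# A minimal, hypothetical sketch of a QuerySalesforce subclass (object name, query
# and output path are placeholders, not part of the original module):
#
#   class ContactExport(QuerySalesforce):
#       @property
#       def object_name(self):
#           return 'Contact'
#
#       @property
#       def soql(self):
#           return "SELECT Id, Name, Email FROM Contact"
#
#       def output(self):
#           return luigi.LocalTarget('contacts.csv')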
class SalesforceAPI(object):
"""
    Class used to interact with the Salesforce API. Currently provides only the
    methods necessary for performing a bulk query or upload operation.
"""
API_VERSION = 34.0
SOAP_NS = "{urn:partner.soap.sforce.com}"
API_NS = "{http://www.force.com/2009/06/asyncapi/dataload}"
def __init__(self, username, password, security_token, sb_token=None, sandbox_name=None):
self.username = username
self.password = password
self.security_token = security_token
self.sb_security_token = sb_token
self.sandbox_name = sandbox_name
if self.sandbox_name:
self.username += ".%s" % self.sandbox_name
self.session_id = None
self.server_url = None
self.hostname = None
    def start_session(self):
"""
Starts a Salesforce session and determines which SF instance to use for future requests.
"""
if self.has_active_session():
raise Exception("Session already in progress.")
response = requests.post(self._get_login_url(),
headers=self._get_login_headers(),
data=self._get_login_xml())
response.raise_for_status()
root = ET.fromstring(response.text)
for e in root.iter("%ssessionId" % self.SOAP_NS):
if self.session_id:
raise Exception("Invalid login attempt. Multiple session ids found.")
self.session_id = e.text
for e in root.iter("%sserverUrl" % self.SOAP_NS):
if self.server_url:
raise Exception("Invalid login attempt. Multiple server urls found.")
self.server_url = e.text
if not self.has_active_session():
raise Exception("Invalid login attempt resulted in null sessionId [%s] and/or serverUrl [%s]." %
(self.session_id, self.server_url))
self.hostname = urlsplit(self.server_url).hostname
    def has_active_session(self):
return self.session_id and self.server_url
    def query(self, query, **kwargs):
"""
Return the result of a Salesforce SOQL query as a dict decoded from the Salesforce response JSON payload.
:param query: the SOQL query to send to Salesforce, e.g. "SELECT id from Lead WHERE email = 'a@b.com'"
"""
params = {'q': query}
response = requests.get(self._get_norm_query_url(),
headers=self._get_rest_headers(),
params=params,
**kwargs)
if response.status_code != requests.codes.ok:
raise Exception(response.content)
return response.json()
    def query_more(self, next_records_identifier, identifier_is_url=False, **kwargs):
"""
Retrieves more results from a query that returned more results
than the batch maximum. Returns a dict decoded from the Salesforce
response JSON payload.
:param next_records_identifier: either the Id of the next Salesforce
object in the result, or a URL to the
next record in the result.
:param identifier_is_url: True if `next_records_identifier` should be
treated as a URL, False if
`next_records_identifer` should be treated as
an Id.
"""
if identifier_is_url:
# Don't use `self.base_url` here because the full URI is provided
url = (u'https://{instance}{next_record_url}'
.format(instance=self.hostname,
next_record_url=next_records_identifier))
else:
            url = self._get_norm_query_url() + '/{next_record_id}'
url = url.format(next_record_id=next_records_identifier)
response = requests.get(url, headers=self._get_rest_headers(), **kwargs)
response.raise_for_status()
return response.json()
    def query_all(self, query, **kwargs):
"""
Returns the full set of results for the `query`. This is a
convenience wrapper around `query(...)` and `query_more(...)`.
The returned dict is the decoded JSON payload from the final call to
Salesforce, but with the `totalSize` field representing the full
number of results retrieved and the `records` list representing the
full list of records retrieved.
:param query: the SOQL query to send to Salesforce, e.g.
`SELECT Id FROM Lead WHERE Email = "waldo@somewhere.com"`
"""
# Make the initial query to Salesforce
response = self.query(query, **kwargs)
# get fields
fields = get_soql_fields(query)
# put fields and first page of results into a temp list to be written to TempFile
tmp_list = [fields]
tmp_list.extend(parse_results(fields, response))
tmp_dir = luigi.configuration.get_config().get('salesforce', 'local-tmp-dir', None)
tmp_file = tempfile.TemporaryFile(mode='a+b', dir=tmp_dir)
writer = csv.writer(tmp_file)
writer.writerows(tmp_list)
# The number of results might have exceeded the Salesforce batch limit
# so check whether there are more results and retrieve them if so.
length = len(response['records'])
while not response['done']:
response = self.query_more(response['nextRecordsUrl'], identifier_is_url=True, **kwargs)
writer.writerows(parse_results(fields, response))
length += len(response['records'])
if not length % 10000:
logger.info('Requested {0} lines...'.format(length))
logger.info('Requested a total of {0} lines.'.format(length))
tmp_file.seek(0)
return tmp_file
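    # Hypothetical usage sketch of the REST query path (credentials and query are
    # placeholders, not part of the original module):
    #
    #   sf = SalesforceAPI('user@example.com', 'secret', 'TOKEN')
    #   sf.start_session()
    #   csv_file = sf.query_all("SELECT Id, Name FROM Contact")  # temp file of CSV rows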
# Generic Rest Function
    def restful(self, path, params):
"""
Allows you to make a direct REST call if you know the path
Arguments:
        :param path: The path of the request. Example: 'sobjects/User/ABC123/password'
:param params: dict of parameters to pass to the path
"""
url = self._get_norm_base_url() + path
response = requests.get(url, headers=self._get_rest_headers(), params=params)
if response.status_code != 200:
raise Exception(response)
json_result = response.json(object_pairs_hook=OrderedDict)
if len(json_result) == 0:
return None
else:
return json_result
    def create_operation_job(self, operation, obj, external_id_field_name=None, content_type=None):
"""
        Creates a new SF job for performing an operation (insert, upsert, update, delete, query).
:param operation: delete, insert, query, upsert, update, hardDelete. Must be lowercase.
:param obj: Parent SF object
:param external_id_field_name: Optional.
"""
if not self.has_active_session():
self.start_session()
response = requests.post(self._get_create_job_url(),
headers=self._get_create_job_headers(),
data=self._get_create_job_xml(operation, obj, external_id_field_name, content_type))
response.raise_for_status()
root = ET.fromstring(response.text)
job_id = root.find('%sid' % self.API_NS).text
return job_id
    def get_job_details(self, job_id):
"""
Gets all details for existing job
:param job_id: job_id as returned by 'create_operation_job(...)'
:return: job info as xml
"""
response = requests.get(self._get_job_details_url(job_id))
response.raise_for_status()
return response
    def abort_job(self, job_id):
"""
Abort an existing job. When a job is aborted, no more records are processed.
Changes to data may already have been committed and aren't rolled back.
:param job_id: job_id as returned by 'create_operation_job(...)'
:return: abort response as xml
"""
response = requests.post(self._get_abort_job_url(job_id),
headers=self._get_abort_job_headers(),
data=self._get_abort_job_xml())
response.raise_for_status()
return response
    def close_job(self, job_id):
"""
Closes job
:param job_id: job_id as returned by 'create_operation_job(...)'
:return: close response as xml
"""
if not job_id or not self.has_active_session():
raise Exception("Can not close job without valid job_id and an active session.")
response = requests.post(self._get_close_job_url(job_id),
headers=self._get_close_job_headers(),
data=self._get_close_job_xml())
response.raise_for_status()
return response
    def create_batch(self, job_id, data, file_type):
"""
Creates a batch with either a string of data or a file containing data.
If a file is provided, this will pull the contents of the file_target into memory when running.
That shouldn't be a problem for any files that meet the Salesforce single batch upload
size limit (10MB) and is done to ensure compressed files can be uploaded properly.
:param job_id: job_id as returned by 'create_operation_job(...)'
        :param data: the query string or file contents to upload in the batch
:return: Returns batch_id
"""
if not job_id or not self.has_active_session():
raise Exception("Can not create a batch without a valid job_id and an active session.")
headers = self._get_create_batch_content_headers(file_type)
headers['Content-Length'] = str(len(data))
response = requests.post(self._get_create_batch_url(job_id),
headers=headers,
data=data)
response.raise_for_status()
root = ET.fromstring(response.text)
batch_id = root.find('%sid' % self.API_NS).text
return batch_id
    def block_on_batch(self, job_id, batch_id, sleep_time_seconds=5, max_wait_time_seconds=-1):
"""
Blocks until @batch_id is completed or failed.
:param job_id:
:param batch_id:
:param sleep_time_seconds:
:param max_wait_time_seconds:
"""
if not job_id or not batch_id or not self.has_active_session():
raise Exception("Can not block on a batch without a valid batch_id, job_id and an active session.")
start_time = time.time()
status = {}
while max_wait_time_seconds < 0 or time.time() - start_time < max_wait_time_seconds:
status = self._get_batch_info(job_id, batch_id)
logger.info("Batch %s Job %s in state %s. %s records processed. %s records failed." %
(batch_id, job_id, status['state'], status['num_processed'], status['num_failed']))
if status['state'].lower() in ["completed", "failed"]:
return status
time.sleep(sleep_time_seconds)
raise Exception("Batch did not complete in %s seconds. Final status was: %s" % (sleep_time_seconds, status))
    def get_batch_results(self, job_id, batch_id):
"""
DEPRECATED: Use `get_batch_result_ids`
"""
warnings.warn("get_batch_results is deprecated and only returns one batch result. Please use get_batch_result_ids")
return self.get_batch_result_ids(job_id, batch_id)[0]
    def get_batch_result_ids(self, job_id, batch_id):
"""
Get result IDs of a batch that has completed processing.
:param job_id: job_id as returned by 'create_operation_job(...)'
:param batch_id: batch_id as returned by 'create_batch(...)'
:return: list of batch result IDs to be used in 'get_batch_result(...)'
"""
response = requests.get(self._get_batch_results_url(job_id, batch_id),
headers=self._get_batch_info_headers())
response.raise_for_status()
root = ET.fromstring(response.text)
result_ids = [r.text for r in root.findall('%sresult' % self.API_NS)]
return result_ids
    def get_batch_result(self, job_id, batch_id, result_id):
"""
Gets result back from Salesforce as whatever type was originally sent in create_batch (xml, or csv).
:param job_id:
:param batch_id:
:param result_id:
"""
response = requests.get(self._get_batch_result_url(job_id, batch_id, result_id),
headers=self._get_session_headers())
response.raise_for_status()
return response.content
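    # Hypothetical end-to-end sketch of a bulk query using the methods above
    # (object name and query are placeholders, not part of the original module):
    #
    #   sf = SalesforceAPI('user@example.com', 'secret', 'TOKEN')
    #   job_id = sf.create_operation_job('query', 'Contact', content_type='CSV')
    #   batch_id = sf.create_batch(job_id, "SELECT Id, Name FROM Contact", 'CSV')
    #   sf.block_on_batch(job_id, batch_id)
    #   for result_id in sf.get_batch_result_ids(job_id, batch_id):
    #       data = sf.get_batch_result(job_id, batch_id, result_id)
    #   sf.close_job(job_id)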
def _get_batch_info(self, job_id, batch_id):
response = requests.get(self._get_batch_info_url(job_id, batch_id),
headers=self._get_batch_info_headers())
response.raise_for_status()
root = ET.fromstring(response.text)
result = {
"state": root.find('%sstate' % self.API_NS).text,
"num_processed": root.find('%snumberRecordsProcessed' % self.API_NS).text,
"num_failed": root.find('%snumberRecordsFailed' % self.API_NS).text,
}
if root.find('%sstateMessage' % self.API_NS) is not None:
result['state_message'] = root.find('%sstateMessage' % self.API_NS).text
return result
def _get_login_url(self):
server = "login" if not self.sandbox_name else "test"
return "https://%s.salesforce.com/services/Soap/u/%s" % (server, self.API_VERSION)
def _get_base_url(self):
return "https://%s/services" % self.hostname
def _get_bulk_base_url(self):
# Expands on Base Url for Bulk
return "%s/async/%s" % (self._get_base_url(), self.API_VERSION)
def _get_norm_base_url(self):
# Expands on Base Url for Norm
return "%s/data/v%s" % (self._get_base_url(), self.API_VERSION)
def _get_norm_query_url(self):
# Expands on Norm Base Url
return "%s/query" % self._get_norm_base_url()
def _get_create_job_url(self):
# Expands on Bulk url
return "%s/job" % (self._get_bulk_base_url())
def _get_job_id_url(self, job_id):
# Expands on Job Creation url
return "%s/%s" % (self._get_create_job_url(), job_id)
def _get_job_details_url(self, job_id):
# Expands on basic Job Id url
return self._get_job_id_url(job_id)
def _get_abort_job_url(self, job_id):
# Expands on basic Job Id url
return self._get_job_id_url(job_id)
def _get_close_job_url(self, job_id):
# Expands on basic Job Id url
return self._get_job_id_url(job_id)
def _get_create_batch_url(self, job_id):
# Expands on basic Job Id url
return "%s/batch" % (self._get_job_id_url(job_id))
def _get_batch_info_url(self, job_id, batch_id):
# Expands on Batch Creation url
return "%s/%s" % (self._get_create_batch_url(job_id), batch_id)
def _get_batch_results_url(self, job_id, batch_id):
# Expands on Batch Info url
return "%s/result" % (self._get_batch_info_url(job_id, batch_id))
def _get_batch_result_url(self, job_id, batch_id, result_id):
# Expands on Batch Results url
return "%s/%s" % (self._get_batch_results_url(job_id, batch_id), result_id)
def _get_login_headers(self):
headers = {
'Content-Type': "text/xml; charset=UTF-8",
'SOAPAction': 'login'
}
return headers
def _get_session_headers(self):
headers = {
'X-SFDC-Session': self.session_id
}
return headers
def _get_norm_session_headers(self):
headers = {
'Authorization': 'Bearer %s' % self.session_id
}
return headers
def _get_rest_headers(self):
headers = self._get_norm_session_headers()
headers['Content-Type'] = 'application/json'
return headers
def _get_job_headers(self):
headers = self._get_session_headers()
headers['Content-Type'] = "application/xml; charset=UTF-8"
return headers
def _get_create_job_headers(self):
return self._get_job_headers()
def _get_abort_job_headers(self):
return self._get_job_headers()
def _get_close_job_headers(self):
return self._get_job_headers()
def _get_create_batch_content_headers(self, content_type):
headers = self._get_session_headers()
content_type = 'text/csv' if content_type.lower() == 'csv' else 'application/xml'
headers['Content-Type'] = "%s; charset=UTF-8" % content_type
return headers
def _get_batch_info_headers(self):
return self._get_session_headers()
def _get_login_xml(self):
return """<?xml version="1.0" encoding="utf-8" ?>
<env:Envelope xmlns:xsd="http://www.w3.org/2001/XMLSchema"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:env="http://schemas.xmlsoap.org/soap/envelope/">
<env:Body>
<n1:login xmlns:n1="urn:partner.soap.sforce.com">
<n1:username>%s</n1:username>
<n1:password>%s%s</n1:password>
</n1:login>
</env:Body>
</env:Envelope>
""" % (self.username, self.password, self.security_token if self.sandbox_name is None else self.sb_security_token)
def _get_create_job_xml(self, operation, obj, external_id_field_name, content_type):
external_id_field_name_element = "" if not external_id_field_name else \
"\n<externalIdFieldName>%s</externalIdFieldName>" % external_id_field_name
# Note: "Unable to parse job" error may be caused by reordering fields.
# ExternalIdFieldName element must be before contentType element.
return """<?xml version="1.0" encoding="UTF-8"?>
<jobInfo xmlns="http://www.force.com/2009/06/asyncapi/dataload">
<operation>%s</operation>
<object>%s</object>
%s
<contentType>%s</contentType>
</jobInfo>
""" % (operation, obj, external_id_field_name_element, content_type)
def _get_abort_job_xml(self):
return """<?xml version="1.0" encoding="UTF-8"?>
<jobInfo xmlns="http://www.force.com/2009/06/asyncapi/dataload">
<state>Aborted</state>
</jobInfo>
"""
def _get_close_job_xml(self):
return """<?xml version="1.0" encoding="UTF-8"?>
<jobInfo xmlns="http://www.force.com/2009/06/asyncapi/dataload">
<state>Closed</state>
</jobInfo>
"""