Source code for luigi.contrib.salesforce

# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import abc
import logging
import warnings
import xml.etree.ElementTree as ET
from collections import OrderedDict
import re
import csv
import tempfile

import luigi
from luigi import Task

logger = logging.getLogger('luigi-interface')

try:
    import requests
except ImportError:
    logger.warning("This module requires the python package 'requests'.")

try:
    from urlparse import urlsplit
except ImportError:
    from urllib.parse import urlsplit


def get_soql_fields(soql):
    """
    Gets queried columns names.
    """
    soql_fields = re.search('(?<=select)(?s)(.*)(?=from)', soql, re.IGNORECASE)  # get fields
    soql_fields = re.sub(' ', '', soql_fields.group())  # remove extra spaces
    soql_fields = re.sub('\t', '', soql_fields)  # remove tabs
    fields = re.split(',|\n|\r', soql_fields)  # split on commas and newlines
    fields = [field for field in fields if field != '']  # remove empty strings
    return fields

def ensure_utf(value):
    return value.encode("utf-8") if isinstance(value, unicode) else value

def parse_results(fields, data):
    """
    Traverses ordered dictionary, calls _traverse_results() to recursively read into the dictionary depth of data
    """
    master = []

    for record in data['records']:  # for each 'record' in response
        row = [None] * len(fields)  # create null list the length of number of columns
        for obj, value in record.iteritems():  # for each obj in record
            if not isinstance(value, (dict, list, tuple)):  # if not data structure
                if obj in fields:
                    row[fields.index(obj)] = ensure_utf(value)

            elif isinstance(value, dict) and obj != 'attributes':  # traverse down into object
                path = obj
                _traverse_results(value, fields, row, path)

        master.append(row)
    return master

def _traverse_results(value, fields, row, path):
    """
    Helper method for parse_results().

    Traverses through ordered dict and recursively calls itself when encountering a dictionary
    """
    for f, v in value.iteritems():  # for each item in obj
        field_name = '{path}.{name}'.format(path=path, name=f) if path else f

        if not isinstance(v, (dict, list, tuple)):  # if not data structure
            if field_name in fields:
                row[fields.index(field_name)] = ensure_utf(v)

        elif isinstance(v, dict) and f != 'attributes':  # it is a dict
            _traverse_results(v, fields, row, field_name)

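# Illustrative sketch (not part of the original module): how get_soql_fields()
# and parse_results() above cooperate to flatten a REST query response into
# CSV-style rows. The SOQL string and sample payload are made up, and the
# helpers as written rely on Python 2 idioms (iteritems, unicode).
def _example_flatten_response():
    soql = "SELECT Id, Owner.Name FROM Lead"
    fields = get_soql_fields(soql)  # -> ['Id', 'Owner.Name']
    response = {
        'records': [
            {'attributes': {'type': 'Lead'},
             'Id': '00Q000000000001',
             'Owner': {'attributes': {'type': 'User'}, 'Name': 'Jane Doe'}},
        ]
    }
    # Each record becomes a row ordered like `fields`; nested objects are
    # traversed and addressed with dotted names such as 'Owner.Name'.
    return [fields] + parse_results(fields, response)
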
class salesforce(luigi.Config):
    """
    Config system to get config vars from 'salesforce' section in configuration file.

    Did not include sandbox_name here, as the user may have multiple sandboxes.
    """
    username = luigi.Parameter(default='')
    password = luigi.Parameter(default='')
    security_token = luigi.Parameter(default='')
    # sandbox token
    sb_security_token = luigi.Parameter(default='')

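# A minimal configuration sketch (not part of the original module): the values
# read by the `salesforce` Config class above would typically live in the
# [salesforce] section of the luigi configuration file. The credential values
# below are placeholders.
#
#   [salesforce]
#   username=integration.user@example.com
#   password=secret
#   security_token=PRODUCTION_TOKEN
#   sb_security_token=SANDBOX_TOKEN
#   # optional, used by SalesforceAPI.query_all() for its temporary file
#   local-tmp-dir=/tmp
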
class QuerySalesforce(Task):
    @abc.abstractproperty
    def object_name(self):
        """
        Override to return the SF object we are querying.
        Must have the SF "__c" suffix if it is a custom object.
        """
        return None

    @property
    def use_sandbox(self):
        """
        Override to specify use of SF sandbox.
        True iff we should be uploading to a sandbox environment instead of the production organization.
        """
        return False

    @property
    def sandbox_name(self):
        """Override to specify the sandbox name if it is intended to be used."""
        return None

    @abc.abstractproperty
    def soql(self):
        """Override to return the raw string SOQL or the path to it."""
        return None

    @property
    def is_soql_file(self):
        """Override to True if soql property is a file path."""
        return False

    @property
    def content_type(self):
        """
        Override to use a different content type. Salesforce allows XML, CSV, ZIP_CSV, or ZIP_XML. Defaults to CSV.
        """
        return "CSV"

    def run(self):
        if self.use_sandbox and not self.sandbox_name:
            raise Exception("Parameter sf_sandbox_name must be provided when uploading to a Salesforce Sandbox")

        sf = SalesforceAPI(salesforce().username, salesforce().password, salesforce().security_token,
                           salesforce().sb_security_token, self.sandbox_name)

        job_id = sf.create_operation_job('query', self.object_name, content_type=self.content_type)
        logger.info("Started query job %s in salesforce for object %s" % (job_id, self.object_name))

        batch_id = ''
        msg = ''
        try:
            if self.is_soql_file:
                with open(self.soql, 'r') as infile:
                    self.soql = infile.read()

            batch_id = sf.create_batch(job_id, self.soql, self.content_type)
            logger.info("Creating new batch %s to query: %s for job: %s." % (batch_id, self.object_name, job_id))
            status = sf.block_on_batch(job_id, batch_id)
            if status['state'].lower() == 'failed':
                msg = "Batch failed with message: %s" % status['state_message']
                logger.error(msg)
                # don't raise exception if it's b/c of an included relationship
                # normal query will execute (with relationship) after bulk job is closed
                if 'foreign key relationships not supported' not in status['state_message'].lower():
                    raise Exception(msg)
            else:
                result_ids = sf.get_batch_result_ids(job_id, batch_id)

                # If there's only one result, just download it, otherwise we need to merge the resulting downloads
                if len(result_ids) == 1:
                    data = sf.get_batch_result(job_id, batch_id, result_ids[0])
                    with open(self.output().path, 'wb') as outfile:
                        outfile.write(data)
                else:
                    # Download each file to disk, and then merge into one.
                    # Preferring to do it this way so as to minimize memory consumption.
                    for i, result_id in enumerate(result_ids):
                        logger.info("Downloading batch result %s for batch: %s and job: %s" % (result_id, batch_id, job_id))
                        with open("%s.%d" % (self.output().path, i), 'wb') as outfile:
                            outfile.write(sf.get_batch_result(job_id, batch_id, result_id))

                    logger.info("Merging results of batch %s" % batch_id)
                    self.merge_batch_results(result_ids)
        finally:
            logger.info("Closing job %s" % job_id)
            sf.close_job(job_id)

        if 'state_message' in status and 'foreign key relationships not supported' in status['state_message'].lower():
            logger.info("Retrying with REST API query")
            data_file = sf.query_all(self.soql)

            reader = csv.reader(data_file)
            with open(self.output().path, 'wb') as outfile:
                writer = csv.writer(outfile, dialect='excel')
                for row in reader:
                    writer.writerow(row)

    def merge_batch_results(self, result_ids):
        """
        Merges the resulting files of a multi-result batch bulk query.
        """
        outfile = open(self.output().path, 'w')

        if self.content_type.lower() == 'csv':
            for i, result_id in enumerate(result_ids):
                with open("%s.%d" % (self.output().path, i), 'r') as f:
                    header = f.readline()
                    if i == 0:
                        outfile.write(header)
                    for line in f:
                        outfile.write(line)
        else:
            raise Exception("Batch result merging not implemented for %s" % self.content_type)

        outfile.close()

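# Usage sketch (not part of the original module): a minimal QuerySalesforce
# subclass. The object name, SOQL, and output path are hypothetical; the
# credentials come from the [salesforce] Config section above.
class _ExampleContactExport(QuerySalesforce):
    @property
    def object_name(self):
        return 'Contact'

    @property
    def soql(self):
        return "SELECT Id, Email FROM Contact"

    def output(self):
        # Local CSV target chosen only for illustration.
        return luigi.LocalTarget('contacts.csv')
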
class SalesforceAPI(object):
    """
    Class used to interact with the SalesforceAPI. Currently provides only the
    methods necessary for performing a bulk upload operation.
    """
    API_VERSION = 34.0
    SOAP_NS = "{urn:partner.soap.sforce.com}"
    API_NS = "{http://www.force.com/2009/06/asyncapi/dataload}"

    def __init__(self, username, password, security_token, sb_token=None, sandbox_name=None):
        self.username = username
        self.password = password
        self.security_token = security_token
        self.sb_security_token = sb_token
        self.sandbox_name = sandbox_name

        if self.sandbox_name:
            self.username += ".%s" % self.sandbox_name

        self.session_id = None
        self.server_url = None
        self.hostname = None

    def start_session(self):
        """
        Starts a Salesforce session and determines which SF instance to use for future requests.
        """
        if self.has_active_session():
            raise Exception("Session already in progress.")

        response = requests.post(self._get_login_url(),
                                 headers=self._get_login_headers(),
                                 data=self._get_login_xml())
        response.raise_for_status()

        root = ET.fromstring(response.text)
        for e in root.iter("%ssessionId" % self.SOAP_NS):
            if self.session_id:
                raise Exception("Invalid login attempt. Multiple session ids found.")
            self.session_id = e.text

        for e in root.iter("%sserverUrl" % self.SOAP_NS):
            if self.server_url:
                raise Exception("Invalid login attempt. Multiple server urls found.")
            self.server_url = e.text

        if not self.has_active_session():
            raise Exception("Invalid login attempt resulted in null sessionId [%s] and/or serverUrl [%s]." %
                            (self.session_id, self.server_url))

        self.hostname = urlsplit(self.server_url).hostname

    def has_active_session(self):
        return self.session_id and self.server_url

    def query(self, query, **kwargs):
        """
        Return the result of a Salesforce SOQL query as a dict decoded from the Salesforce response JSON payload.

        :param query: the SOQL query to send to Salesforce, e.g. "SELECT id from Lead WHERE email = 'a@b.com'"
        """
        params = {'q': query}
        response = requests.get(self._get_norm_query_url(),
                                headers=self._get_rest_headers(),
                                params=params,
                                **kwargs)

        if response.status_code != requests.codes.ok:
            raise Exception(response.content)

        return response.json()

    def query_more(self, next_records_identifier, identifier_is_url=False, **kwargs):
        """
        Retrieves more results from a query that returned more results than the batch maximum.
        Returns a dict decoded from the Salesforce response JSON payload.

        :param next_records_identifier: either the Id of the next Salesforce object in the result,
                                        or a URL to the next record in the result.
        :param identifier_is_url: True if `next_records_identifier` should be treated as a URL,
                                  False if `next_records_identifier` should be treated as an Id.
        """
        if identifier_is_url:
            # Don't use `self.base_url` here because the full URI is provided
            url = (u'https://{instance}{next_record_url}'
                   .format(instance=self.hostname,
                           next_record_url=next_records_identifier))
        else:
            url = self._get_norm_query_url() + '{next_record_id}'
            url = url.format(next_record_id=next_records_identifier)

        response = requests.get(url, headers=self._get_rest_headers(), **kwargs)
        response.raise_for_status()

        return response.json()

    def query_all(self, query, **kwargs):
        """
        Returns the full set of results for the `query`. This is a convenience wrapper around
        `query(...)` and `query_more(...)`.

        The results are written as CSV rows to a temporary file, with the first row holding the
        queried field names; the returned value is that file object, rewound to the beginning.

        :param query: the SOQL query to send to Salesforce, e.g.
                      `SELECT Id FROM Lead WHERE Email = "waldo@somewhere.com"`
        """
        # Make the initial query to Salesforce
        response = self.query(query, **kwargs)

        # get fields
        fields = get_soql_fields(query)

        # put fields and first page of results into a temp list to be written to TempFile
        tmp_list = [fields]
        tmp_list.extend(parse_results(fields, response))

        tmp_dir = luigi.configuration.get_config().get('salesforce', 'local-tmp-dir', None)
        tmp_file = tempfile.TemporaryFile(mode='a+b', dir=tmp_dir)

        writer = csv.writer(tmp_file)
        writer.writerows(tmp_list)

        # The number of results might have exceeded the Salesforce batch limit
        # so check whether there are more results and retrieve them if so.
        length = len(response['records'])
        while not response['done']:
            response = self.query_more(response['nextRecordsUrl'], identifier_is_url=True, **kwargs)
            writer.writerows(parse_results(fields, response))
            length += len(response['records'])
            if not length % 10000:
                logger.info('Requested {0} lines...'.format(length))

        logger.info('Requested a total of {0} lines.'.format(length))

        tmp_file.seek(0)
        return tmp_file

    # Generic Rest Function
    def restful(self, path, params):
        """
        Allows you to make a direct REST call if you know the path.

        :param path: The path of the request. Example: sobjects/User/ABC123/password
        :param params: dict of parameters to pass to the path
        """
        url = self._get_norm_base_url() + path
        response = requests.get(url, headers=self._get_rest_headers(), params=params)

        if response.status_code != 200:
            raise Exception(response)
        json_result = response.json(object_pairs_hook=OrderedDict)
        if len(json_result) == 0:
            return None
        else:
            return json_result

    def create_operation_job(self, operation, obj, external_id_field_name=None, content_type=None):
        """
        Creates a new SF job for doing any operation (insert, upsert, update, delete, query).

        :param operation: delete, insert, query, upsert, update, hardDelete. Must be lowercase.
        :param obj: Parent SF object
        :param external_id_field_name: Optional.
        """
        if not self.has_active_session():
            self.start_session()

        response = requests.post(self._get_create_job_url(),
                                 headers=self._get_create_job_headers(),
                                 data=self._get_create_job_xml(operation, obj, external_id_field_name, content_type))
        response.raise_for_status()

        root = ET.fromstring(response.text)
        job_id = root.find('%sid' % self.API_NS).text

        return job_id

    def get_job_details(self, job_id):
        """
        Gets all details for existing job

        :param job_id: job_id as returned by 'create_operation_job(...)'
        :return: job info as xml
        """
        response = requests.get(self._get_job_details_url(job_id))
        response.raise_for_status()
        return response

    def abort_job(self, job_id):
        """
        Abort an existing job. When a job is aborted, no more records are processed.
        Changes to data may already have been committed and aren't rolled back.

        :param job_id: job_id as returned by 'create_operation_job(...)'
        :return: abort response as xml
        """
        response = requests.post(self._get_abort_job_url(job_id),
                                 headers=self._get_abort_job_headers(),
                                 data=self._get_abort_job_xml())
        response.raise_for_status()
        return response

    def close_job(self, job_id):
        """
        Closes job

        :param job_id: job_id as returned by 'create_operation_job(...)'
        :return: close response as xml
        """
        if not job_id or not self.has_active_session():
            raise Exception("Can not close job without valid job_id and an active session.")

        response = requests.post(self._get_close_job_url(job_id),
                                 headers=self._get_close_job_headers(),
                                 data=self._get_close_job_xml())
        response.raise_for_status()
        return response

    def create_batch(self, job_id, data, file_type):
        """
        Creates a batch with either a string of data or a file containing data.

        If a file is provided, this will pull the contents of the file_target into memory when running.
        That shouldn't be a problem for any files that meet the Salesforce single batch upload
        size limit (10MB) and is done to ensure compressed files can be uploaded properly.

        :param job_id: job_id as returned by 'create_operation_job(...)'
        :param data:
        :return: Returns batch_id
        """
        if not job_id or not self.has_active_session():
            raise Exception("Can not create a batch without a valid job_id and an active session.")

        headers = self._get_create_batch_content_headers(file_type)
        headers['Content-Length'] = str(len(data))

        response = requests.post(self._get_create_batch_url(job_id),
                                 headers=headers,
                                 data=data)
        response.raise_for_status()

        root = ET.fromstring(response.text)
        batch_id = root.find('%sid' % self.API_NS).text
        return batch_id

    def block_on_batch(self, job_id, batch_id, sleep_time_seconds=5, max_wait_time_seconds=-1):
        """
        Blocks until @batch_id is completed or failed.

        :param job_id:
        :param batch_id:
        :param sleep_time_seconds:
        :param max_wait_time_seconds:
        """
        if not job_id or not batch_id or not self.has_active_session():
            raise Exception("Can not block on a batch without a valid batch_id, job_id and an active session.")

        start_time = time.time()
        status = {}
        while max_wait_time_seconds < 0 or time.time() - start_time < max_wait_time_seconds:
            status = self._get_batch_info(job_id, batch_id)
            logger.info("Batch %s Job %s in state %s. %s records processed. %s records failed." %
                        (batch_id, job_id, status['state'], status['num_processed'], status['num_failed']))
            if status['state'].lower() in ["completed", "failed"]:
                return status
            time.sleep(sleep_time_seconds)

        raise Exception("Batch did not complete in %s seconds. Final status was: %s" % (max_wait_time_seconds, status))

    def get_batch_results(self, job_id, batch_id):
        """
        DEPRECATED: Use `get_batch_result_ids`
        """
        warnings.warn("get_batch_results is deprecated and only returns one batch result. Please use get_batch_result_ids")
        return self.get_batch_result_ids(job_id, batch_id)[0]

    def get_batch_result_ids(self, job_id, batch_id):
        """
        Get result IDs of a batch that has completed processing.

        :param job_id: job_id as returned by 'create_operation_job(...)'
        :param batch_id: batch_id as returned by 'create_batch(...)'
        :return: list of batch result IDs to be used in 'get_batch_result(...)'
        """
        response = requests.get(self._get_batch_results_url(job_id, batch_id),
                                headers=self._get_batch_info_headers())
        response.raise_for_status()

        root = ET.fromstring(response.text)
        result_ids = [r.text for r in root.findall('%sresult' % self.API_NS)]

        return result_ids

    def get_batch_result(self, job_id, batch_id, result_id):
        """
        Gets result back from Salesforce as whatever type was originally sent in create_batch (xml, or csv).

        :param job_id:
        :param batch_id:
        :param result_id:
        """
        response = requests.get(self._get_batch_result_url(job_id, batch_id, result_id),
                                headers=self._get_session_headers())
        response.raise_for_status()
        return response.content

    def _get_batch_info(self, job_id, batch_id):
        response = requests.get(self._get_batch_info_url(job_id, batch_id),
                                headers=self._get_batch_info_headers())
        response.raise_for_status()

        root = ET.fromstring(response.text)
        result = {
            "state": root.find('%sstate' % self.API_NS).text,
            "num_processed": root.find('%snumberRecordsProcessed' % self.API_NS).text,
            "num_failed": root.find('%snumberRecordsFailed' % self.API_NS).text,
        }
        if root.find('%sstateMessage' % self.API_NS) is not None:
            result['state_message'] = root.find('%sstateMessage' % self.API_NS).text
        return result

    def _get_login_url(self):
        server = "login" if not self.sandbox_name else "test"
        return "https://%s.salesforce.com/services/Soap/u/%s" % (server, self.API_VERSION)

    def _get_base_url(self):
        return "https://%s/services" % self.hostname

    def _get_bulk_base_url(self):
        # Expands on Base Url for Bulk
        return "%s/async/%s" % (self._get_base_url(), self.API_VERSION)

    def _get_norm_base_url(self):
        # Expands on Base Url for Norm
        return "%s/data/v%s" % (self._get_base_url(), self.API_VERSION)

    def _get_norm_query_url(self):
        # Expands on Norm Base Url
        return "%s/query" % self._get_norm_base_url()

    def _get_create_job_url(self):
        # Expands on Bulk url
        return "%s/job" % (self._get_bulk_base_url())

    def _get_job_id_url(self, job_id):
        # Expands on Job Creation url
        return "%s/%s" % (self._get_create_job_url(), job_id)

    def _get_job_details_url(self, job_id):
        # Expands on basic Job Id url
        return self._get_job_id_url(job_id)

    def _get_abort_job_url(self, job_id):
        # Expands on basic Job Id url
        return self._get_job_id_url(job_id)

    def _get_close_job_url(self, job_id):
        # Expands on basic Job Id url
        return self._get_job_id_url(job_id)

    def _get_create_batch_url(self, job_id):
        # Expands on basic Job Id url
        return "%s/batch" % (self._get_job_id_url(job_id))

    def _get_batch_info_url(self, job_id, batch_id):
        # Expands on Batch Creation url
        return "%s/%s" % (self._get_create_batch_url(job_id), batch_id)

    def _get_batch_results_url(self, job_id, batch_id):
        # Expands on Batch Info url
        return "%s/result" % (self._get_batch_info_url(job_id, batch_id))

    def _get_batch_result_url(self, job_id, batch_id, result_id):
        # Expands on Batch Results url
        return "%s/%s" % (self._get_batch_results_url(job_id, batch_id), result_id)

    def _get_login_headers(self):
        headers = {
            'Content-Type': "text/xml; charset=UTF-8",
            'SOAPAction': 'login'
        }
        return headers

    def _get_session_headers(self):
        headers = {
            'X-SFDC-Session': self.session_id
        }
        return headers

    def _get_norm_session_headers(self):
        headers = {
            'Authorization': 'Bearer %s' % self.session_id
        }
        return headers

    def _get_rest_headers(self):
        headers = self._get_norm_session_headers()
        headers['Content-Type'] = 'application/json'
        return headers

    def _get_job_headers(self):
        headers = self._get_session_headers()
        headers['Content-Type'] = "application/xml; charset=UTF-8"
        return headers

    def _get_create_job_headers(self):
        return self._get_job_headers()

    def _get_abort_job_headers(self):
        return self._get_job_headers()

    def _get_close_job_headers(self):
        return self._get_job_headers()

    def _get_create_batch_content_headers(self, content_type):
        headers = self._get_session_headers()
        content_type = 'text/csv' if content_type.lower() == 'csv' else 'application/xml'
        headers['Content-Type'] = "%s; charset=UTF-8" % content_type
        return headers

    def _get_batch_info_headers(self):
        return self._get_session_headers()

    def _get_login_xml(self):
        return """<?xml version="1.0" encoding="utf-8" ?>
        <env:Envelope xmlns:xsd="http://www.w3.org/2001/XMLSchema"
            xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
            xmlns:env="http://schemas.xmlsoap.org/soap/envelope/">
          <env:Body>
            <n1:login xmlns:n1="urn:partner.soap.sforce.com">
              <n1:username>%s</n1:username>
              <n1:password>%s%s</n1:password>
            </n1:login>
          </env:Body>
        </env:Envelope>
        """ % (self.username, self.password,
               self.security_token if self.sandbox_name is None else self.sb_security_token)

    def _get_create_job_xml(self, operation, obj, external_id_field_name, content_type):
        external_id_field_name_element = "" if not external_id_field_name else \
            "\n<externalIdFieldName>%s</externalIdFieldName>" % external_id_field_name

        # Note: "Unable to parse job" error may be caused by reordering fields.
        # ExternalIdFieldName element must be before contentType element.
        return """<?xml version="1.0" encoding="UTF-8"?>
        <jobInfo xmlns="http://www.force.com/2009/06/asyncapi/dataload">
            <operation>%s</operation>
            <object>%s</object>
            %s
            <contentType>%s</contentType>
        </jobInfo>
        """ % (operation, obj, external_id_field_name_element, content_type)

    def _get_abort_job_xml(self):
        return """<?xml version="1.0" encoding="UTF-8"?>
        <jobInfo xmlns="http://www.force.com/2009/06/asyncapi/dataload">
          <state>Aborted</state>
        </jobInfo>
        """

    def _get_close_job_xml(self):
        return """<?xml version="1.0" encoding="UTF-8"?>
        <jobInfo xmlns="http://www.force.com/2009/06/asyncapi/dataload">
          <state>Closed</state>
        </jobInfo>
        """
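
# Usage sketch (not part of the original module): driving SalesforceAPI
# directly for a single-result bulk query, mirroring what QuerySalesforce.run()
# does above. Credentials, object name, SOQL, and the output path are all
# placeholders.
def _example_bulk_query():
    sf = SalesforceAPI(salesforce().username,
                       salesforce().password,
                       salesforce().security_token)
    job_id = sf.create_operation_job('query', 'Contact', content_type='CSV')
    try:
        batch_id = sf.create_batch(job_id, "SELECT Id, Email FROM Contact", 'CSV')
        status = sf.block_on_batch(job_id, batch_id)
        if status['state'].lower() == 'completed':
            # Assume a single result set for this sketch; multi-result batches
            # would need merging, as in QuerySalesforce.merge_batch_results().
            result_id = sf.get_batch_result_ids(job_id, batch_id)[0]
            with open('contacts.csv', 'wb') as outfile:
                outfile.write(sf.get_batch_result(job_id, batch_id, result_id))
    finally:
        sf.close_job(job_id)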