diff --git a/source/notebooks/sagemaker_fraud_detection.ipynb b/source/notebooks/sagemaker_fraud_detection.ipynb index 8530358..2f50ddf 100644 --- a/source/notebooks/sagemaker_fraud_detection.ipynb +++ b/source/notebooks/sagemaker_fraud_detection.ipynb @@ -310,12 +310,9 @@ }, "outputs": [], "source": [ - "from sagemaker.serializers import CSVSerializer\n", + "from sagemaker.serializers import CSVSerializer \n", "from sagemaker.deserializers import JSONDeserializer\n", - " \n", - "rcf_predictor.content_type = 'text/csv'\n", - "rcf_predictor.serializer = CSVSerializer()\n", - "rcf_predictor.accept = 'application/json'\n", + "rcf_predictor.serializer = CSVSerializer() \n", "rcf_predictor.deserializer = JSONDeserializer()" ] }, @@ -372,6 +369,7 @@ "metadata": {}, "outputs": [], "source": [ + "%pip install seaborn --q\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "sns.set(color_codes=True)" @@ -655,12 +653,12 @@ "source": [ "from sagemaker.serializers import CSVSerializer\n", "\n", + "\n", "predictor = clf.deploy(initial_instance_count=1,\n", " model_name=\"{}-xgb\".format(config.SOLUTION_PREFIX),\n", " endpoint_name=\"{}-xgb\".format(config.SOLUTION_PREFIX),\n", " instance_type=instance_type,\n", - " serializer=CSVSerializer(),\n", - " deserializer=None)" + " serializer=CSVSerializer())" ] }, { @@ -696,7 +694,7 @@ " split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))\n", " predictions = ''\n", " for array in split_array:\n", - " predictions = ','.join([predictions, current_predictor.predict(array).decode('utf-8')])\n", + " predictions = ','.join([predictions, current_predictor.predict(array)])\n", "\n", " return np.fromstring(predictions[1:], sep=',')" ] @@ -868,6 +866,18 @@ "pip install aws_requests_auth" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "for name in (\"botocore.credentials\", \"botocore\", \"boto3\", \"s3transfer\", \"urllib3\"):\n", + " logging.getLogger(name).setLevel(logging.WARNING) # or logging.ERROR / logging.CRITICAL" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1065,13 +1075,25 @@ "# Specify input and output formats.\n", "smote_predictor.content_type = 'text/csv'\n", "csv_serializer = CSVSerializer()\n", - "smote_predictor.serializer = csv_serializer\n", "\n", "# Set the deserializer to handle the response from the inference endpoint\n", "#csv_deserializer = CSVDeserializer()\n", "#smote_predictor.deserializer = csv_deserializer" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.serializers import CSVSerializer\n", + "from sagemaker.deserializers import StringDeserializer # predictions often come back as text\n", + "\n", + "smote_predictor.serializer = CSVSerializer() # sets Content-Type to text/csv\n", + "smote_predictor.deserializer = StringDeserializer()" + ] + }, { "cell_type": "code", "execution_count": null,