|
import json |
|
from functools import lru_cache |
|
from logging import INFO, error, getLogger, info |
|
from typing import Sequence, Any, Dict, Iterable, Iterator, NamedTuple |
|
|
|
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel |
|
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc |
|
from clarifai_grpc.grpc.api.status import status_code_pb2 |
|
|
|
# Set the root logger's level to INFO so the module-level info() calls used
# throughout this file are emitted instead of being filtered by that level.
getLogger().setLevel(INFO)


# Chunk size passed to make_chunks() when splitting rows into
# PostWorkflowResults requests (presumably Clarifai's per-request input limit
# for workflows — confirm against the Clarifai API docs).
_MAX_WORKFLOW_INPUT = 32
|
|
|
|
|
class Row(NamedTuple):
    """
    Represent the list of arguments passed to the external function for one row.

    The first argument is always the row number, which must be present in both
    the request and the response. Fields are filled positionally from the
    request payload (``Row(*row)`` in ``get_rows``), so their order matters.
    """

    # Row number; returned alongside this row's prediction.
    row_number: int
    # A case-insensitive "image"/"img" prefix selects an image input;
    # any other value is sent as a text input.
    input_type: str
    # Sent as the ``url`` of the Clarifai Image/Text input.
    url: str
    # Sent as the ``workflow_id`` of the PostWorkflowResults request.
    workflow_name: str
    # Sent as ``UserAppIDSet.app_id`` of the request.
    app_id: str
    # Sent as "Key <pat>" in the authorization metadata; treat it as a secret.
    pat: str
|
|
|
|
|
class EasyException(Exception):
    """
    Error whose message is safe to hand back to the caller.

    ``lambda_handler`` puts ``str()`` of this exception into the error
    response; any other exception type is replaced by a generic message.
    """
|
|
|
|
|
def lambda_handler(event, context):
    """
    Handle one batch request: predict every row and return results by row number.

    Rows are parsed with ``get_rows``, sent to Clarifai via ``predict`` and each
    prediction is returned as ``[row_number, value]`` in a success response.
    Any failure becomes an error response: ``EasyException`` messages are
    returned verbatim, anything else is reported with a generic message (the
    full traceback goes to the log).
    """
    try:
        # Do not log the raw event: its body carries every row's PAT.
        info("Start event for request %s", getattr(context, "aws_request_id", None))

        # Map each distinct Row to its row number. The whole Row (row number
        # included) is the key, so only exact duplicates collapse, and dict
        # insertion order keeps the request order.
        input_map = {row: row.row_number for row in get_rows(event)}
        rows = tuple(input_map)

        if not rows:
            return make_success_response(context, [])

        # Redact the PAT before the inputs reach the logs.
        info("Inputs are: %s", [row._replace(pat="***") for row in rows])
        outputs = predict(*rows)

        snowflake_rows = list(match_input_with_row_number(rows, outputs, input_map))
        return make_success_response(context, snowflake_rows)
    except Exception as err:
        error(err, exc_info=True)
        if isinstance(err, EasyException):
            error_message = str(err)
        else:
            error_message = "An unexpected exception occurred"
        return make_error_response(context, error_message)
|
|
|
|
|
def match_input_with_row_number(inputs, outputs, row_id_map):
    """
    Pair each Clarifai prediction with the row number of the input that produced it.

    *inputs* and *outputs* are walked in lockstep (stopping at the shorter of the
    two). For every pair, this lazily yields
    ``[row_id_map[input], make_prediction_response(prediction)]``.
    """
    yield from (
        [row_id_map[row], make_prediction_response(labels)]
        for row, labels in zip(inputs, outputs)
    )
|
|
|
|
|
def get_rows(event) -> Iterator[Row]:
    """
    Parse the request rows out of the Lambda event.

    The event's ``body`` must be a JSON string of the form ``{"data": [...]}``.
    Each row is an array whose first element is the row number and whose
    remaining elements are the function arguments, in ``Row`` field order.

    Malformed input is reported with an ``EasyException``, so the caller can
    return a meaningful message instead of a generic failure.

    :raises EasyException: if the body is missing, not a string, not valid
        JSON, or not a JSON object; if ``data`` is not a list; or if a row is
        not an array with exactly one value per ``Row`` field.
    """
    event_body = event.get("body", None)
    if not event_body:
        raise EasyException("Body not set")
    if not isinstance(event_body, str):
        raise EasyException("Body is expected to be string")

    try:
        payload = json.loads(event_body)
    except ValueError as err:  # json.JSONDecodeError subclasses ValueError
        raise EasyException("Body is not valid JSON") from err
    if not isinstance(payload, dict):
        raise EasyException("Body is expected to be a JSON object")

    data = payload.get("data", [])
    if not isinstance(data, list):
        raise EasyException("Body data is expected to be a list")

    field_count = len(Row._fields)
    for row in data:
        # Row(*row) would otherwise fail with a bare TypeError on a wrong arity.
        if not isinstance(row, list) or len(row) != field_count:
            raise EasyException(
                f"Each row is expected to be a list of {field_count} values"
            )
        yield Row(*row)
|
|
|
|
|
def make_chunks(sequence, size):
    """
    Yield consecutive, non-overlapping slices of *sequence* with at most *size* items.

    The last slice may be shorter, and an empty sequence yields nothing.
    *sequence* must support ``len()`` and slicing; the slices keep its type
    (for example, a tuple yields tuples).

    :raises ValueError: if *size* is smaller than 1 (such a size could never
        advance through the sequence).
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size!r}")
    for start in range(0, len(sequence), size):
        # The end bound is exclusive: exactly `size` items, no overlap with the next chunk.
        yield sequence[start : start + size]
|
|
|
|
|
def get_input_chunk(rows: Sequence[Row]):
    """
    Yield ``(inputs, metadata, workflow_id, user_app_id)`` for each chunk of *rows*.

    The workflow id, app id and PAT are read from the first row only, so all
    rows passed in are assumed to share them. Rows are split with
    ``make_chunks(rows, _MAX_WORKFLOW_INPUT)`` and each chunk is converted with
    ``make_input``. The metadata and ``UserAppIDSet`` objects are built once and
    shared by every yielded tuple. An empty *rows* yields nothing.
    """
    if not rows:
        # Nothing to send; avoids an IndexError on rows[0].
        return

    first_row = rows[0]
    workflow_id = first_row.workflow_name
    meta = make_metadata(first_row.pat)
    user_app_id = resources_pb2.UserAppIDSet(app_id=first_row.app_id)

    for chunk in make_chunks(rows, _MAX_WORKFLOW_INPUT):
        yield make_input(chunk), meta, workflow_id, user_app_id
|
|
|
|
|
def predict(*data: Row, threshold=0.9) -> Sequence[Sequence[Dict[str, str]]]:
    """
    Run the Clarifai workflow(s) over *data* and collect the confident concepts.

    Consecutive rows that share the same workflow, app id and PAT are sent
    together (split into requests by ``get_input_chunk``). Rows that address a
    different workflow, app or PAT get their own requests instead of silently
    reusing the first row's settings.

    :param data: rows to classify; results come back in the same order.
    :param threshold: minimum concept value (inclusive) for a concept to be kept.
    :return: one list per input row, in input order. Each list holds dicts with
        ``name`` and ``value`` (formatted with 5 decimals), plus ``text`` for
        concepts found inside a region, sorted by ``name``.
    :raises EasyException: if a request does not succeed, or if it returns a
        different number of results than inputs sent.
    """
    from itertools import groupby

    stub = make_connection()
    image_concepts = []

    # groupby only merges *consecutive* rows with equal keys, so the input order
    # is preserved while every group shares one workflow, app id and PAT
    # (get_input_chunk reads those from its first row).
    for _target, group in groupby(data, key=_request_target):
        for inputs, meta, workflow_id, auth in get_input_chunk(tuple(group)):
            response = stub.PostWorkflowResults(
                service_pb2.PostWorkflowResultsRequest(
                    user_app_id=auth,
                    workflow_id=workflow_id,
                    inputs=inputs,
                ),
                metadata=meta,
            )
            info("Got response status %s", response.status)
            if response.status.code != status_code_pb2.SUCCESS:
                raise EasyException(_describe_failure(response))

            # Results are matched to rows purely by position downstream, so a
            # missing or extra result would shift every following row.
            if len(response.results) != len(inputs):
                raise EasyException(
                    f"Expected {len(inputs)} workflow results, "
                    f"got {len(response.results)}"
                )

            image_concepts.extend(
                _collect_concepts(result, threshold) for result in response.results
            )

    return image_concepts


def _request_target(row):
    """Return the (workflow, app id, PAT) triple a row must be sent with."""
    return row.workflow_name, row.app_id, row.pat


def _describe_failure(response) -> str:
    """
    Build the error message for a PostWorkflowResults response that did not succeed.

    If any individual output failed with details, the first such detail is
    returned. Otherwise the request-level description and details are used.
    """
    for result in response.results:
        for output in result.outputs:
            if output.status.code != status_code_pb2.SUCCESS and output.status.details:
                return output.status.details
    status = response.status
    return f"Failed to get prediction: {status.description} {status.details}".rstrip()


def _collect_concepts(result, threshold):
    """Return one workflow result's concepts with value >= threshold, sorted by name."""
    concepts = []
    for output in result.outputs:
        # Region-level concepts also carry the region's raw text.
        for region in output.data.regions:
            concepts.extend(
                {
                    "text": region.data.text.raw,
                    "name": concept.name,
                    "value": f"{concept.value:.5f}",
                }
                for concept in region.data.concepts
                if concept.value >= threshold
            )

        concepts.extend(
            {"name": concept.name, "value": f"{concept.value:.5f}"}
            for concept in output.data.concepts
            if concept.value >= threshold
        )
    return sorted(concepts, key=lambda item: item["name"])
|
|
|
|
|
@lru_cache(maxsize=1)
def make_connection():
    """
    Return a Clarifai ``V2Stub`` over the JSON channel.

    The result is cached, so every call after the first returns the same stub.
    """
    return service_pb2_grpc.V2Stub(ClarifaiChannel.get_json_channel())
|
|
|
|
|
def make_metadata(pat):
    """Build the call metadata carrying ``pat`` as a "Key <pat>" authorization header."""
    auth_header = ("authorization", f"Key {pat}")
    return (auth_header,)
|
|
|
|
|
def is_image(row: Row) -> bool:
    """Tell whether ``row.input_type`` starts with "image" or "img", ignoring case."""
    kind = row.input_type.lower()
    return kind.startswith("image") or kind.startswith("img")
|
|
|
|
|
def _make_input(row: Row) -> resources_pb2.Input:
    """Wrap the row's URL in a Clarifai Input: as an image if is_image(row), else as text."""
    payload = (
        resources_pb2.Data(image=resources_pb2.Image(url=row.url))
        if is_image(row)
        else resources_pb2.Data(text=resources_pb2.Text(url=row.url))
    )
    return resources_pb2.Input(data=payload)
|
|
|
|
|
def make_input(rows: Iterable[Row]) -> Sequence[resources_pb2.Input]:
    """Convert every row into a Clarifai Input, keeping the row order, and return them as a list."""
    return list(map(_make_input, rows))
|
|
|
|
|
def make_success_response(context, data: Any):
    """
    Build an HTTP 200 response.

    The JSON body holds the Lambda request id followed by *data*; the body
    dict is logged before it is serialized.
    """
    payload = {"request_id": context.aws_request_id}
    payload["data"] = data
    info("Response to return: %s", payload)
    return {"statusCode": 200, "body": json.dumps(payload)}
|
|
|
|
|
def make_error_response(context, error_message: str):
    """Log *error_message* and build an HTTP 400 response carrying it with the request id."""
    info("Error to return: %s", error_message)
    envelope = {
        "request_id": context.aws_request_id,
        "error": {"message": error_message},
    }
    return {"statusCode": 400, "body": json.dumps(envelope)}
|
|
|
|
|
def make_prediction_response(labels: Sequence) -> Dict[str, Any]:
    """
    Wrap one row's predictions as the single value returned for that row.

    Each row gets exactly one value back, which may be a VARIANT (a dict in
    Python) or an ARRAY (a list). Here the labels are nested under a ``tags``
    key, producing a VARIANT.
    """
    response = dict(tags=labels)
    return response