Published on

AWSとEmpathとRaspberryPi Voicekitで音声感情認識

Authors
  • avatar
    Name
    Kikusan
    Twitter

AWS Lambda,APIGateWay,DynamoDB,S3,IoTCore:Empath:RspberryPiVoicekitを使って、音声感情認識に取り組む。

AWSの基本設定

基本東京リージョンで作業

IAM

ユーザとロールを作成する。ポリシーは以下をアタッチする。(※フルアクセス)

ユーザはアクセスキーも発行する。

  • AWSLambda_FullAccess
  • AmazonAPIGatewayAdministrator
  • AWSIoTFullAccess
  • AmazonS3FullAccess

IoTCore

IoT Coreは

モノに証明書が付き、証明書にポリシーをアタッチする。

モノごとに証明書を変えてもいいし、証明書に複数のポリシーをアタッチしてもいい。

適当な名前で「モノ」を作成すると、1-Click証明書作成ができる。

  • 証明書.pem
  • パブリックキー.public.key
  • プライベートキー.private.key
  • AmazonRootCA1.pem

をダウンロード。

できた証明書にポリシーをアタッチする。(※フルアクセス)

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "iot:*"
      ],
      "Resource": [
        "*"
      ]
    }
  ]
}

Lambda

二つ関数を作成する

  1. IoTCoreから声を受け取り、EmpathAPIで感情を取得し、DynamoDBに保存し、結果を別のトピックに投げる関数
  • Lambda, DynamoDB, IoTCoreのポリシーがあるロールで使用する。
  • 環境変数にIOT_ENDPOINTと、ENPATH_URL,EMPATH_KEYを設定する。
  • requestsは標準モジュールにないので、レイヤーの作成から以下コマンドでできるzipファイルをアップロードしてレイヤを作成。(展開するとupload>python>module)
mkdir python
pip install requests -t python/
zip -r upload.zip python/
  • レイヤができたらこの関数でレイヤーの追加する。
  • トリガーにはAWS IoTのカスタムIoTルールで、ルールクエリーには以下を指定する。
SELECT * FROM "empath/+/input"
  • Lambda関数
import json
import boto3
import base64
import os
import requests
from datetime import datetime, timedelta


IOT_ENDPOINT = os.environ['IOT_ENDPOINT']
EMPATH_KEY = os.environ['EMPATH_KEY']
EMPATH_URL = os.environ['EMPATH_URL']
iot = boto3.client('iot-data', endpoint_url=IOT_ENDPOINT)
dynamoDB = boto3.resource("dynamodb")
table = dynamoDB.Table("empath")


def lambda_handler(event, context):
    topic = f"empath/{event['device-name']}/return"

    try:
        wav_bytes = base64.b64decode(event['payload'])
        empath_payload = {'apikey': EMPATH_KEY}
        file = {'wav': wav_bytes}
        res = requests.post(EMPATH_URL, params=empath_payload, files=file)
        
        now = (datetime.utcnow() + timedelta(hours=9))
        num_now = now.second + 10**2 *now.minute + 10**4 * now.hour + 10**6 * now.day + 10**8 * now.month + 10**10 * now.year
        table.put_item(
            Item={
                "CustomerId": event['device-name'],
                "TimeStamp": num_now,
                "empath": res.text
            }
        )
        
        iot.publish(
            topic=topic,
            qos=0,
            payload=res.text
        )
        
        return {
            'statusCode': 200,
            'body': 'Succeeded.'
        }
    
    except Exception as e:
        print("【error】", e)
        return {
            'statusCode': 500,
            'body': 'Failed.'
        }
  1. APIGatewayからIDを受け取り、最新の感情を返す関数(getDbData)
  • Lambda, DynamoDB, APIGatewayのポリシーがあるロールで使用する。
  • トリガーはAPIGatewayのgetdbdataを指定する(あとでAPIGateway側から指定する)
  • Lambda関数
import boto3
import json
from boto3.dynamodb.conditions import Key
from decimal import Decimal


def decimal_default_proc(obj):
    if isinstance(obj, Decimal):
        return float(obj)
    raise TypeError


def lambda_handler(event, context):
    try:
        dynamoDB = boto3.resource("dynamodb")
        table = dynamoDB.Table("empath")
        
        queryData = table.query(
            KeyConditionExpression = Key("CustomerId").eq(event['id']), #キー情報
            ScanIndexForward = False, #降順
            Limit = 10
        )
        
        return {
            'statusCode': 200,
            'body': json.dumps(queryData, default=decimal_default_proc)
        }

    except Exception as e:
        print(e)
        return {
            'statusCode': 500,
            'body': str(e)
        }

DynamoDB

以下設定でtableを作成する。

  • table : empath
    • partitionkey : CustomerId
    • sortkey : TimeStamp

APIGateway

以下設定でメソッドを作成する。

  • REST API 構築
  • API名 : getDbData
  • リソースの作成 : getDbData / getdbdata
  • メソッドの作成 : GET (チェックをクリック)
    • 統合タイプ : Lambda関数
    • 関数名 : getDbData
    • getdbdata>CORSの有効化
    • メソッドリクエスト : クエリ文字列パラメータ
      • id (チェックをクリック)
    • メソッドリクエスト : HTTPリクエストヘッダー
      • Content-Type (チェックをクリック)
    • 統合リクエスト : マッピングテンプレート
      • リクエスト本文のパススルー : テンプレートが定義されていない場合 (推奨)
      • Content-Type : application/json
      • テンプレート
#set($inputRoot=$input.path('$'))
{
  "id": "$input.params('id')"
}

これで APIGateway のテストで

  • id=1

  • Content-Type=application/json で送ると、Lambdaのreturnが帰ってくる。

  • APIのデプロイ

    • ステージ : v1
    • ⇒ URLが取得できる。

S3

感情を見るページをS3でホストする。

  • バケットの作成
    • 公開アクセス権限なし
    • ファイルアップロード : index.html, index.js
    • S3>バケット>プロパティ
      • 静的ウェブサイトホスティング
        • 有効
        • インデックスドキュメント : index.html
        • エラードキュメント : index.html
    • S3>バケット>アクセス許可
      • ブロックパブリックアクセス : すべて許可
      • バケットポリシー
{
   "Version":"2012-10-17",
   "Statement":[{
    "Sid":"PublicReadForGetBucketObjects",
         "Effect":"Allow",
      "Principal": "*",
       "Action":["s3:GetObject"],
       "Resource":["arn:aws:s3:::[バケット名]/*"
       ]
     }
   ]
 }

これでindex.htmlのURLが取得できる。

index.htmlのソース

jsonで感情を取得して、google chartsで表示する。

  • index.html
<html>
  <head>
    <script type="text/javascript" src="https://www.gstatic.com/charts/loader.js"></script>
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/moment.js/2.11.2/moment.min.js" type="text/javascript"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/crypto-js/3.1.2/components/core-min.js" type="text/javascript"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/crypto-js/3.1.2/components/hmac-min.js" type="text/javascript"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/crypto-js/3.1.2/components/sha256-min.js" type="text/javascript"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/paho-mqtt/1.0.1/mqttws31.min.js" type="text/javascript"></script>
    <script src="index.js"></script>
  </head>
  <body>
    <div id="empath_chart" style="width: 900px; height: 500px"></div>
  </body>
</html>
  • index.js
/**
 * google charts
 */

const ID = 'voicekit';
const COUNT = 10;
const REGION =  'ap-northeast-1';
const END_POINT = 'XXXXX-ats.iot.ap-northeast-1.amazonaws.com'; //iot core endpoint
const ACCESS_KEY = 'XXXXXXXX'; //aws access key
const SECRET_KEY = 'XXXXXXXX'; //aws private key

const URL = "https://XXXXXX.execute-api.ap-northeast-1.amazonaws.com/v1/getdbdata?id=" + ID //api gateway

google.charts.load('current', {'packages':['corechart']});
const dataArray = [[{label:'TimeStamp', id:'timestamp', type: 'date'}, 'calm', 'anger', 'joy', 'sorrow', 'energy']];

function init() {
  $.getJSON(URL).done(function(res) {
      if(res.statusCode != 200){
        console.error("Error params is not valid:", res.body);
        return;
      }
      let body = JSON.parse(res.body);
      for(let item of body.Items){
        data = []
        s_time = String(item.TimeStamp)
        Y = Number(s_time.slice(0,4));
        M = Number(s_time.slice(4,6)) - 1;
        D = Number(s_time.slice(6,8));
        H = Number(s_time.slice(8,10));
        m = Number(s_time.slice(10,12));
        s = Number(s_time.slice(12,14));
        data.push(new Date(Y, M, D, H, m, s));
        let empath = JSON.parse(item.empath);
        data.push(empath.calm);
        data.push(empath.anger);
        data.push(empath.joy);
        data.push(empath.sorrow);
        data.push(empath.energy);
        dataArray.push(data);
      }
      dataArray.sort(function(a, b) {
        return (a[0]- b[0]);
      });
    })
  .fail(function(jqxhr, textStatus, error) {
    let err = textStatus + ", " + error;
    console.error("Error getting empath json:", err);
  });
}

function drawChart() {
  var data = google.visualization.arrayToDataTable(dataArray);

  var options = {
    title: 'Your Empath',
    curveType: 'function',
    legend: { position: 'bottom' }
  };

  var chart = new google.visualization.LineChart(document.getElementById('empath_chart'));

  chart.draw(data, options);
}

/**
 * mqtt websocket
 */

function SigV4Utils(){}

SigV4Utils.sign = function(key, msg) {
  var hash = CryptoJS.HmacSHA256(msg, key);
  return hash.toString(CryptoJS.enc.Hex);
};

SigV4Utils.sha256 = function(msg) {
  var hash = CryptoJS.SHA256(msg);
  return hash.toString(CryptoJS.enc.Hex);
};

SigV4Utils.getSignatureKey = function(key, dateStamp, regionName, serviceName) {
  var kDate = CryptoJS.HmacSHA256(dateStamp, 'AWS4' + key);
  var kRegion = CryptoJS.HmacSHA256(regionName, kDate);
  var kService = CryptoJS.HmacSHA256(serviceName, kRegion);
  var kSigning = CryptoJS.HmacSHA256('aws4_request', kService);
  return kSigning;
};

function createEndpoint(regionName, awsIotEndpoint, accessKey, secretKey) {
  var time = moment.utc();
  var dateStamp = time.format('YYYYMMDD');
  var amzdate = dateStamp + 'T' + time.format('HHmmss') + 'Z';
  var service = 'iotdevicegateway';
  var region = regionName;
  var secretKey = secretKey;
  var accessKey = accessKey;
  var algorithm = 'AWS4-HMAC-SHA256';
  var method = 'GET';
  var canonicalUri = '/mqtt';
  var host = awsIotEndpoint;

  var credentialScope = dateStamp + '/' + region + '/' + service + '/' + 'aws4_request';
  var canonicalQuerystring = 'X-Amz-Algorithm=AWS4-HMAC-SHA256';
  canonicalQuerystring += '&X-Amz-Credential=' + encodeURIComponent(accessKey + '/' + credentialScope);
  canonicalQuerystring += '&X-Amz-Date=' + amzdate;
  canonicalQuerystring += '&X-Amz-SignedHeaders=host';

  var canonicalHeaders = 'host:' + host + '\n';
  var payloadHash = SigV4Utils.sha256('');
  var canonicalRequest = method + '\n' + canonicalUri + '\n' + canonicalQuerystring + '\n' + canonicalHeaders + '\nhost\n' + payloadHash;

  var stringToSign = algorithm + '\n' +  amzdate + '\n' +  credentialScope + '\n' +  SigV4Utils.sha256(canonicalRequest);
  var signingKey = SigV4Utils.getSignatureKey(secretKey, dateStamp, region, service);
  var signature = SigV4Utils.sign(signingKey, stringToSign);

  canonicalQuerystring += '&X-Amz-Signature=' + signature;
  return 'wss://' + host + canonicalUri + '?' + canonicalQuerystring;
}

function subscribe() {
  client.subscribe("empath/" + ID + "/return");
  console.log("subscribed");
}

function send(content) {
  var message = new Paho.MQTT.Message(content);
  message.destinationName = "Test/chat";
  client.send(message);
  console.log("sent");
}

function onMessage(message) {
  console.log("message received: " + message.payloadString);
  data = []
  console.log(new Date());
  data.push(new Date());
  let empath = JSON.parse(message.payloadString);
  data.push(empath.calm);
  data.push(empath.anger);
  data.push(empath.joy);
  data.push(empath.sorrow);
  data.push(empath.energy);
  if (dataArray.length > (1 + COUNT)){
    dataArray.splice(1, 1); // 二番目を削除
  }
  dataArray.push(data);
  drawChart();
}

/**
 * main
 */

init();
setTimeout(() => {
  google.charts.setOnLoadCallback(drawChart);  
}, 1000)

var endpoint = createEndpoint(
    REGION,
    END_POINT,
    ACCESS_KEY,
    SECRET_KEY);
var clientId = Math.random().toString(36).substring(7);
var client = new Paho.MQTT.Client(endpoint, clientId);
var connectOptions = {
  useSSL: true,
  timeout: 3,
  mqttVersion: 4,
  onSuccess: subscribe
};
client.connect(connectOptions);
client.onMessageArrived = onMessage;
client.onConnectionLost = function(e) { console.log(e) };

ラズパイ側

ラズパイからは音声をpyAudioで取得し、それをIoTCoreへMQTTで送信する。

  • 必要ライブラリのインポート
sudo apt install libportaudio2 libportaudiocpp0 portaudio19-dev
pip install pyaudio
pip3 install requests
sudo pip3 install python-dotenv
sudo pip3 install AWSIoTPythonSDK
  • ファイルの配置
empath
├─app
│      .env
│      app.py
│      config.py
│      gateway.py
│      recorder.py
│      temp.wav
└─aws
        XXXXX-certificate.pem.crt        
        XXXXX-private.pem.key
        XXXXX-public.pem.key
        AmazonRootCA1.pem
  • app.py
from gateway import Gateway
import sys

g = Gateway()
print('Press Ctrl+C to exit.')
g.loop()
  • config.py
import os
import pyaudio
from dotenv import load_dotenv

load_dotenv()

DEVICE_NAME = os.environ.get("DEVICE_NAME")
AWS_ENDPOINT = os.environ.get("AWS_ENDPOINT")
AWS_ROOTCA = os.environ.get("AWS_ROOTCA")
AWS_KEY = os.environ.get("AWS_KEY")
AWS_CERT = os.environ.get("AWS_CERT")

WAV_FILE = os.environ.get("WAV_FILE")

RECORD_SECONDS = int(os.environ.get("RECORD_SECONDS")) #録音する時間の長さ(秒)
I_DEVICE_INDEX = int(os.environ.get("DEVICE_INDEX")) #録音デバイスのインデックス番号
CHANNELS = int(os.environ.get("CHANNELS")) #モノラル
RATE = int(os.environ.get("RATE")) #サンプルレート
CHUNK = int(os.environ.get("CHUNK")) #データ点数
FORMAT = pyaudio.paInt16 #音声のフォーマット
  • gateway.py
from io import BytesIO
from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient
from aiy.leds import Leds, Color

import base64
import json
import sys
import time

from recorder import Recorder
import config


class Gateway(object):

    def __init__(self):
        self.recorder = Recorder()
        self.client = self._create_client()
        self.leds = Leds()

    def loop(self):
        self.client.subscribe('empath/'+config.DEVICE_NAME+'/return', 
                              1, self.show_empath)
        err_f = True
        while err_f:
            try:
                self.recorder.record()
                bwav = self.recorder.to_byte()
                self.send(bwav)
                time.sleep(10)
            except KeyboardInterrupt as e:
                print(e)
                self.leds.update(Leds.rgb_off())
                err_f = False

    def send(self, bytes_data):
        data = base64.b64encode(bytes_data).decode("utf-8")
        self.client.publish(
            topic='empath/'+config.DEVICE_NAME+'/input',
            payload=json.dumps({
                "device-name": config.DEVICE_NAME,
                "payload": data,
            }),
            QoS=1
        )
    
    def show_empath(self, client, userdata, message):
        data = json.loads(message.payload.decode('utf-8'))
        print(data)
        if (data['error'] > 0):
            print(data)
            leds.update(Leds.rgb_off())
        # 値が最大のキーを取得
        max_k = max(data, key=data.get)
        print(max_k)
        if max_k == 'calm':
            self.leds.update(Leds.rgb_on(Color.BLUE))
        elif max_k == 'anger':
            self.leds.update(Leds.rgb_on(Color.RED))
        elif max_k == 'joy':
            self.leds.update(Leds.rgb_on(Color.YELLOW))
        elif max_k == 'sorrow':
            self.leds.update(Leds.rgb_on(Color.GREEN))
        elif max_k == 'energy':
            self.leds.update(Leds.rgb_on(Color.PURPLE))

    def _create_client(self):
        # AWS IoTのクライアント作成
        client = AWSIoTMQTTClient(config.DEVICE_NAME)
        # クライアントの初期設定
        client.configureEndpoint(config.AWS_ENDPOINT, 8883)
        client.configureAutoReconnectBackoffTime(1, 32, 20)
        client.configureOfflinePublishQueueing(-1)
        client.configureDrainingFrequency(2)
        client.configureConnectDisconnectTimeout(300)
        client.configureMQTTOperationTimeout(10)
        client.configureCredentials(config.AWS_ROOTCA,
                                    config.AWS_KEY,
                                    config.AWS_CERT)

        client.connect(60)
        client.publish('gateway/'+config.DEVICE_NAME+'/stat', 'connected.', 1)

        return client


if __name__ == '__main__':
    g = Gateway()
    g.recorder.record()
    bwav = g.recorder.to_byte()
    g.send(bwav)
  • recorder.py
from io import BytesIO
from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient

import json
import wave
import pyaudio

import config


class Recorder(object):

    def __init__(self):
        self.audio = pyaudio.PyAudio()
        self.stream = self._create_stream()

    def _create_stream(self):

        stream = self.audio.open(format=config.FORMAT, channels=config.CHANNELS,
                rate=config.RATE, input=True,
                input_device_index = config.I_DEVICE_INDEX, # 録音デバイスのインデックス番号
                frames_per_buffer=config.CHUNK)

        return stream

    def record(self):

        frames = []
        #--------------録音開始---------------
        print ("recording...")
        for i in range(0, int(config.RATE / config.CHUNK * config.RECORD_SECONDS)):
            data = self.stream.read(config.CHUNK, exception_on_overflow=False)
            frames.append(data)

        print ("finished recording")
        #--------------録音終了---------------

        with wave.open(config.WAV_FILE, 'wb') as wave_file: # write binary
            wave_file.setnchannels(config.CHANNELS)
            wave_file.setsampwidth(self.audio.get_sample_size(config.FORMAT))
            wave_file.setframerate(config.RATE)
            wave_file.writeframes(b''.join(frames))

    def to_byte(self):
        with open(config.WAV_FILE, 'rb') as wave_file: # read binary
            bwav = wave_file.read() # to bytes
            return bwav

if __name__ == '__main__':
    audio = pyaudio.PyAudio()
    stream = audio.open(format=config.FORMAT, channels=config.CHANNELS,
        rate=config.RATE, input=True,
        input_device_index = config.I_DEVICE_INDEX, # 録音デバイスのインデックス番号
        frames_per_buffer=config.CHUNK)
    frames = []
    #--------------録音開始---------------
    print ("recording...")
    for i in range(0, int(config.RATE / config.CHUNK * config.RECORD_SECONDS)):
        data = stream.read(config.CHUNK)
        frames.append(data)

    print ("finished recording")
    #--------------録音終了---------------

    with wave.open(config.WAV_FILE, 'wb') as wave_file: # write binary
        wave_file.setnchannels(config.CHANNELS)
        wave_file.setsampwidth(audio.get_sample_size(config.FORMAT))
        wave_file.setframerate(config.RATE)
        wave_file.writeframes(b''.join(frames))
  • .env
DEVICE_NAME=voicekit
AWS_ENDPOINT=XXXXX-ats.iot.ap-northeast-1.amazonaws.com
AWS_ROOTCA=/home/pi/empath/aws/AmazonRootCA1.pem
AWS_KEY=/home/pi/empath/aws/XXXXX-private.pem.key
AWS_CERT=/home/pi/empath/aws/XXXXX-certificate.pem.crt

WAV_FILE=temp.wav

RECORD_SECONDS=3 # 録音秒数
DEVICE_INDEX=0 # デバイス番号
CHANNELS=1 # モノラル
RATE=11025 # サンプルレート
CHUNK = 2048 # データ点数

実行

ラズパイでapp.py を実行すると、index.htmlで感情が見られる。