Source code for ynlu.sdk.client

from typing import List, Tuple

from gql import gql, Client
from gql.transport.requests import RequestsHTTPTransport

from .model import Model


class NLUClient(object):
    """Client which could contain multiple intent classifiers.

    On construction it builds a ``gql`` client for the GraphQL endpoint,
    fetches the id and name of every classifier visible to ``token`` and
    wraps each id in a :class:`Model`. Construction therefore performs
    network requests.

    NOTE: Only support predicting for now
    """

    URL = 'https://ynlu.yoctol.com/graphql'

    def __init__(
            self,
            token: str,
            expected_retries: int = 1,
            url: str = URL,
    ):
        """Create the client and eagerly load all available classifiers.

        Args:
            token: API token, sent as ``Authorization: Bearer <token>``.
            expected_retries: Forwarded to ``gql.Client`` as ``retries``.
            url: GraphQL endpoint; defaults to :attr:`URL`.
        """
        self.token = token
        self._transport = RequestsHTTPTransport(
            url=url,
            use_json=True,
        )
        self._transport.headers = {
            "User-Agent": "Mozilla/5.0 (X11; Ubuntu; "
            + "Linux x86_64; rv:58.0) Gecko/20100101 Firefox/58.0",
            "Authorization": "Bearer {}".format(self.token),
            "content-type": "application/json",
        }
        self._client = self.build_client(retries=expected_retries)
        # Parallel lists: _classifier_names[i] belongs to _classifier_ids[i].
        (
            self._classifier_ids,
            self._classifier_names,
        ) = self.fetch_all_available_clf_ids_and_names()
        self._models = {
            clf_id: Model(clf_id, self._client)
            for clf_id in self._classifier_ids
        }

    def fetch_all_available_clf_ids_and_names(
            self,
    ) -> Tuple[List[str], List[str]]:
        """Query the service for every classifier visible to this token.

        Lists all projects first, then queries the classifiers of each
        project (one request per project).

        Returns:
            An ``(ids, names)`` pair of parallel lists, where ``names[i]``
            is the name of the classifier whose id is ``ids[i]``. Both
            lists are empty when the service reports no projects.
        """
        print("Fetching all classifier's id")
        projects_query = gql("""
            query projects {
                projects {
                    id
                }
            }
        """)
        projects_result = self._client.execute(projects_query)
        project_ids = self._extract_project_ids(projects_result)

        clfs_id = []
        clfs_name = []
        if not project_ids:
            # Always return a pair: __init__ unpacks the result into two
            # attributes, so returning a bare [] would raise ValueError.
            return clfs_id, clfs_name

        # The query text does not depend on the project, so parse it once
        # instead of once per loop iteration.
        clfs_query = gql("""
            query project($id: Int!) {
                project(id: $id) {
                    classifiers {
                        id
                        name
                    }
                }
            }
        """)
        for p_id in project_ids:
            clfs_result = self._client.execute(
                clfs_query,
                variable_values={'id': p_id},
            )
            for clf in self._extract_classifiers(clfs_result):
                clfs_id.append(clf['id'])
                clfs_name.append(clf['name'])
        return clfs_id, clfs_name

    @staticmethod
    def _extract_project_ids(projects_result) -> List:
        """Return the project ids from a ``projects`` query result.

        A missing or null ``projects`` field yields an empty list.
        """
        projects = projects_result.get('projects') or []
        return [project['id'] for project in projects]

    @staticmethod
    def _extract_classifiers(clfs_result) -> List[dict]:
        """Return the classifier entries from a ``project`` query result.

        A missing or null ``project`` or ``classifiers`` field yields an
        empty list instead of raising ``TypeError``.
        """
        project = clfs_result.get('project') or {}
        return project.get('classifiers') or []

    def get_all_available_clf_ids(self) -> List[str]:
        """Return a copy of the known classifier ids."""
        return self._classifier_ids[:]

    def get_all_available_clf_names(self) -> List[str]:
        """Return a copy of the known classifier names (same order as ids)."""
        return self._classifier_names[:]

    def build_client(self, retries: int):
        """Build a ``gql.Client`` on this instance's transport.

        ``fetch_schema_from_transport=True`` asks gql to load the schema
        from the endpoint.
        """
        return Client(
            retries=retries,
            transport=self._transport,
            fetch_schema_from_transport=True,
        )

    def __getitem__(self, key):
        """Shorthand for :meth:`get_model_by_id`."""
        return self.get_model_by_id(key)

    def get_model_by_id(
            self,
            classifier_id: str,
    ) -> Model:
        """Return the model for ``classifier_id``.

        Raises:
            ValueError: if ``classifier_id`` is not a known classifier id.
        """
        self.check_clf_id(classifier_id)
        return self._models[classifier_id]

    def get_model_by_name(
            self,
            classifier_name: str,
    ) -> Model:
        """Return the model for the classifier named ``classifier_name``.

        If several classifiers share the name, the first one fetched wins
        (``list.index`` returns the first match).

        Raises:
            ValueError: if ``classifier_name`` is not a known name.
        """
        self.check_clf_name(classifier_name)
        id_index = self._classifier_names.index(classifier_name)
        classifier_id = self._classifier_ids[id_index]
        return self._models[classifier_id]

    def check_clf_id(
            self,
            classifier_id: str,
    ) -> None:
        """Raise ``ValueError`` unless ``classifier_id`` is known."""
        if classifier_id not in self._classifier_ids:
            raise ValueError('Illegal clf id {}'.format(classifier_id))

    def check_clf_name(
            self,
            classifier_name: str,
    ) -> None:
        """Raise ``ValueError`` unless ``classifier_name`` is known."""
        if classifier_name not in self._classifier_names:
            raise ValueError('Illegal clf name {}'.format(classifier_name))