Commit d21204c2 authored by mahdisellami's avatar mahdisellami
Browse files

Added Idea Categorization resource (including some pre-trained

classifiers)
parent 0cd906d1
......@@ -85,6 +85,16 @@
<version>3.9.1</version>
<classifier>models</classifier> <!-- will get the dependent model jars -->
</dependency>
<dependency>
<groupId>org.apache.opennlp</groupId>
<artifactId>opennlp-tools</artifactId>
<version>1.8.4</version>
</dependency>
<dependency>
<groupId>de.julielab</groupId>
<artifactId>aliasi-lingpipe</artifactId>
<version>4.1.0</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
......
package org.tmms.classification.api;
import java.io.Serializable;
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import javax.persistence.NamedQueries;
import javax.persistence.NamedQuery;
import javax.persistence.Table;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
/**
* A class to store Categorization data.
*/
@ApiModel(description = "Categorization Model")
@Entity
@Table(name = "categorizations")
@NamedQueries({
@NamedQuery(name = "org.tmms.classification.api.Categorization.findAll", query = "select c from Categorization c"),
@NamedQuery(name = "org.tmms.classification.api.Categorization.findByCategory", query = "select c from Categorization c "
+ "where c.category like :category ") })
public class Categorization implements Serializable {
/**
* Entity's unique identifier.
*/
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private long id;
/**
* Input idea.
*/
private String idea;
/**
* output category.
*/
private String category;
/**
* A no-argument constructor.
*/
public Categorization() {
}
/**
* A constructor to create Categorizations. Id is not passed, cause it's
* auto-generated by RDBMS.
*
* @param idea
* the input idea.
* @param category
* the output category.
*/
public Categorization(String idea, String category) {
this.idea = idea;
this.category = category;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((category == null) ? 0 : category.hashCode());
result = prime * result + (int) (id ^ (id >>> 32));
result = prime * result + ((idea == null) ? 0 : idea.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Categorization other = (Categorization) obj;
if (category == null) {
if (other.category != null)
return false;
} else if (!category.equals(other.category))
return false;
if (id != other.id)
return false;
if (idea == null) {
if (other.idea != null)
return false;
} else if (!idea.equals(other.idea))
return false;
return true;
}
@JsonProperty
@ApiModelProperty(value = "The Categorization id", example = "1")
public long getId() {
return id;
}
public void setId(long id) {
this.id = id;
}
@JsonProperty
@ApiModelProperty(value = "The input idea", example = "This is my idea.")
public String getIdea() {
return idea;
}
public void setIdea(String idea) {
this.idea = idea;
}
@JsonProperty
@ApiModelProperty(value = "The Categorization result", example = "Sports")
public String getCategory() {
return category;
}
public void setCategory(String category) {
this.category = category;
}
}
......@@ -6,11 +6,13 @@ import javax.servlet.DispatcherType;
import javax.servlet.FilterRegistration;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.tmms.classification.api.Categorization;
import org.tmms.classification.api.Lemma;
import org.tmms.classification.api.NamedEntityRecognition;
import org.tmms.classification.api.PartOfSpeech;
import org.tmms.classification.api.Sentiment;
import org.tmms.classification.api.Stemming;
import org.tmms.classification.db.CategorizationDAO;
import org.tmms.classification.db.LemmaDAO;
import org.tmms.classification.db.NamedEntityRecognitionDAO;
import org.tmms.classification.db.PartOfSpeechDAO;
......@@ -18,6 +20,7 @@ import org.tmms.classification.db.SentimentDAO;
import org.tmms.classification.db.StemmingDAO;
import org.tmms.classification.health.DatabaseHealthCheck;
import org.tmms.classification.resources.ApiInfoResource;
import org.tmms.classification.resources.CategorizationResource;
import org.tmms.classification.resources.LemmaAnnotationResource;
import org.tmms.classification.resources.NamedEntityRecognitionResource;
import org.tmms.classification.resources.PartOfSpeechTaggingResource;
......@@ -40,7 +43,8 @@ public class classificationserviceApplication extends Application<classification
* Hibernate bundle.
*/
private final HibernateBundle<classificationserviceConfiguration> hibernateBundle = new HibernateBundle<classificationserviceConfiguration>(
Sentiment.class, Lemma.class, PartOfSpeech.class, NamedEntityRecognition.class, Stemming.class) {
Sentiment.class, Lemma.class, PartOfSpeech.class, NamedEntityRecognition.class, Stemming.class,
Categorization.class) {
@Override
public DataSourceFactory getDataSourceFactory(classificationserviceConfiguration configuration) {
return configuration.getDataSourceFactory();
......@@ -99,6 +103,7 @@ public class classificationserviceApplication extends Application<classification
final PartOfSpeechDAO partOfSpeechDAO = new PartOfSpeechDAO(hibernateBundle.getSessionFactory());
final NamedEntityRecognitionDAO nerDAO = new NamedEntityRecognitionDAO(hibernateBundle.getSessionFactory());
final StemmingDAO stemDAO = new StemmingDAO(hibernateBundle.getSessionFactory());
final CategorizationDAO catDAO = new CategorizationDAO(hibernateBundle.getSessionFactory());
final NamedEntityRecognitionResource NERResource = new NamedEntityRecognitionResource(nerDAO);
environment.jersey().register(NERResource);
......@@ -115,6 +120,9 @@ public class classificationserviceApplication extends Application<classification
final StemmingResource stemResource = new StemmingResource(stemDAO);
environment.jersey().register(stemResource);
final CategorizationResource catResource = new CategorizationResource(catDAO);
environment.jersey().register(catResource);
// final TemplateHealthCheck healthCheck = new
// TemplateHealthCheck(configuration.getTemplate());
// environment.healthChecks().register("template", healthCheck);
......
package org.tmms.classification.db;
import java.util.List;
import org.hibernate.SessionFactory;
import org.tmms.classification.api.Categorization;
import io.dropwizard.hibernate.AbstractDAO;
public class CategorizationDAO extends AbstractDAO<Categorization> {
/**
* Constructor.
*
* @param sessionFactory
* Hibernate session factory.
*/
public CategorizationDAO(SessionFactory sessionFactory) {
super(sessionFactory);
}
/**
* Method returns all Categorizations stored in the database.
*
* @return list of all Categorizations stored in the database
*/
@SuppressWarnings("unchecked")
public List<Categorization> findAll() {
return list(namedQuery("org.tmms.classification.api.Categorization.findAll"));
}
/**
* Method looks for a Categorization by its id.
*
* @param id
* the id of a Categorization we are looking for.
* @return Categorization the Categorization with the given id
*/
public Categorization findById(long id) {
return get(id);
}
/**
* Method inserts a Categorization.
*
* @param c
* the Categorization to insert.
* @return Categorization the Categorization inserted.
*/
public Categorization insert(Categorization c) {
return persist(c);
}
/**
* Method looks for Categorizations by category value.
*
* @param category
* the category value we are looking for.
* @return List containing the found categorizations with the given category
* value.
*/
@SuppressWarnings("unchecked")
public List<Categorization> findByCategory(String category) {
return list(namedQuery("org.tmms.classification.api.Categorization.findByCategory").setParameter("category",
category));
}
}
package org.tmms.classification.db;
import java.util.List;
import java.util.Optional;
import org.hibernate.SessionFactory;
import org.tmms.classification.api.Lemma;
......@@ -25,6 +24,7 @@ public class LemmaDAO extends AbstractDAO<Lemma> {
*
* @return list of all Lemmatizations stored in the database
*/
@SuppressWarnings("unchecked")
public List<Lemma> findAll() {
return list(namedQuery("org.tmms.classification.api.Lemma.findAll"));
}
......
package org.tmms.classification.resources;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.core.MediaType;
import org.tmms.classification.api.Categorization;
import org.tmms.classification.db.CategorizationDAO;
import com.aliasi.classify.LMClassifier;
import com.aliasi.util.AbstractExternalizable;
import com.codahale.metrics.annotation.Timed;
import io.dropwizard.hibernate.UnitOfWork;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import io.swagger.annotations.ApiParam;
import io.swagger.annotations.SwaggerDefinition;
import io.swagger.annotations.Tag;
import opennlp.tools.doccat.DoccatModel;
import opennlp.tools.doccat.DocumentCategorizer;
import opennlp.tools.doccat.DocumentCategorizerME;
@Path("/categorize")
@Api(value = "/categorize", tags = { "categorize" })
@SwaggerDefinition(tags = { @Tag(name = "categorize", description = "Idea Categorization Resrouce") })
@Produces(MediaType.APPLICATION_JSON)
public class CategorizationResource {
private CategorizationDAO categorizaionDAO;
public CategorizationResource(CategorizationDAO categorizaionDAO) {
this.categorizaionDAO = categorizaionDAO;
}
@SuppressWarnings("rawtypes")
@GET
@Path("/LM")
@Timed
@UnitOfWork
@ApiOperation("Returns the category of the input idea using a pre-trained Language Model Classifier")
public Categorization categorizeIdeaWithLM(
@ApiParam(value = "input idea", required = true) @QueryParam("idea") String idea)
throws ClassNotFoundException, IOException {
LMClassifier mClassifier = (LMClassifier) AbstractExternalizable
.readObject(new File("src/main/resources/assets/ideas_classifier.model"));
String category = mClassifier.classify(idea).bestCategory();
return categorizaionDAO.insert(new Categorization(idea, category));
}
@GET
@Path("/NB")
@Timed
@UnitOfWork
@ApiOperation("Returns the category of the input idea using a pre-trained Naive-Bayes Classifier")
public Categorization categorizeIdeaWithNB(
@ApiParam(value = "input idea", required = true) @QueryParam("idea") String idea) throws IOException {
InputStream modelIn = new FileInputStream("src/main/resources/assets/de-idea-classifier-naive-bayes.bin");
DoccatModel model = new DoccatModel(modelIn);
DocumentCategorizer doccat = new DocumentCategorizerME(model);
String[] docWords = idea.replaceAll("[^A-Za-z]", " ").split(" ");
double[] aProbs = doccat.categorize(docWords);
String category = doccat.getBestCategory(aProbs);
return categorizaionDAO.insert(new Categorization(idea, category));
}
@GET
@Path("/ME")
@Timed
@UnitOfWork
@ApiOperation("Returns the category of the input idea using a pre-trained Maximum Entropy Classifier")
public Categorization categorizeIdeaWithME(
@ApiParam(value = "input idea", required = true) @QueryParam("idea") String idea) throws IOException {
InputStream modelIn = new FileInputStream("src/main/resources/assets/de-idea-classifier-maxent.bin");
DoccatModel model = new DoccatModel(modelIn);
DocumentCategorizer doccat = new DocumentCategorizerME(model);
String[] docWords = idea.replaceAll("[^A-Za-z]", " ").split(" ");
double[] aProbs = doccat.categorize(docWords);
String category = doccat.getBestCategory(aProbs);
return categorizaionDAO.insert(new Categorization(idea, category));
}
@GET
@Path("/all")
@UnitOfWork
@ApiOperation("Returns all the previous categorizations")
public List<Categorization> findAll() {
return categorizaionDAO.findAll();
}
@GET
@Path("/category")
@UnitOfWork
@ApiOperation("Returns the stored ideas with the given category")
public List<Categorization> findBySentimentValue(
@ApiParam(value = "The category", allowableValues = "Accommodation,Citizen Participation,Environment,Financial,General,Information,Mobility,Organisation,Port,Reuse,Safety,Social,Sports", required = true) @QueryParam("category") String category) {
return categorizaionDAO.findByCategory(category);
}
@GET
@Path("/id")
@UnitOfWork
@ApiOperation("Returns the categorization with the given id")
public Categorization findById(@ApiParam(value = "id", required = true) @QueryParam("id") long id) {
return categorizaionDAO.findById(id);
}
}
......@@ -62,5 +62,16 @@
<constraints nullable="false"/>
</column>
</createTable>
<createTable tableName="categorizations">
<column name="id" type="bigint" autoIncrement="true">
<constraints primaryKey="true" nullable="false"/>
</column>
<column name="idea" type="varchar(255)">
<constraints nullable="false"/>
</column>
<column name="category" type="varchar(255)">
<constraints nullable="false"/>
</column>
</createTable>
</changeSet>
</databaseChangeLog>
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment