edu.stanford.nlp.parser.lexparser
Class SplittingGrammarExtractor

java.lang.Object
  extended by edu.stanford.nlp.parser.lexparser.SplittingGrammarExtractor

public class SplittingGrammarExtractor
extends java.lang.Object

This class is a reimplementation of Berkeley's state-splitting grammar. This work is experimental and still in progress. Several important pieces remain to be implemented:

  1. this code should use log probabilities throughout instead of multiplying tiny numbers
  2. the time efficiency of the training code is awful
  3. there are better ways to extract parses using this grammar than the method in ExhaustivePCFGParser
  4. we should also implement cascading parsers that let us short-circuit low-quality parses earlier (which could possibly benefit non-split parsers as well)
  5. when looping, we should short-circuit if we go through too many iterations
  6. ought to smooth as per page 436
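
As an illustration of item 1, the following is a minimal sketch of log-space arithmetic. The class LogSpace and its methods are hypothetical helpers written for this example; they are not part of SplittingGrammarExtractor. The sketch shows how a product of many small probabilities underflows as a plain double but stays representable as a sum of logs, and how two log probabilities can be added without leaving log space.

```java
import java.util.Arrays;

/** Hypothetical helper for illustration; not part of SplittingGrammarExtractor. */
final class LogSpace {

  /** Returns log(exp(a) + exp(b)) without exponentiating the larger term. */
  static double logAdd(double a, double b) {
    if (a == Double.NEGATIVE_INFINITY) return b;
    if (b == Double.NEGATIVE_INFINITY) return a;
    double max = Math.max(a, b);
    double min = Math.min(a, b);
    return max + Math.log1p(Math.exp(min - max));
  }

  /** Returns the log of the product of the given probabilities. */
  static double logProduct(double[] probs) {
    double sum = 0.0;
    for (double p : probs) {
      sum += Math.log(p);
    }
    return sum;
  }

  public static void main(String[] args) {
    double[] tiny = new double[400];
    Arrays.fill(tiny, 0.1);

    // Direct multiplication: 0.1^400 is below the smallest positive double.
    double direct = 1.0;
    for (double p : tiny) {
      direct *= p;
    }
    System.out.println(direct);            // prints 0.0 (underflow)

    // The same quantity in log space is an ordinary finite number.
    System.out.println(logProduct(tiny));
  }
}
```

Sums over substates, which appear throughout inside/outside computations, would use logAdd in place of +, and products would use + in place of *.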

Author:
John Bauer

Constructor Summary
SplittingGrammarExtractor(Options op)
           
 
Method Summary
 void buildGrammars()
           
 void buildStateIndex()
           
 void countMergeEffects(Tree tree, java.util.Map<java.lang.String,double[]> totalStateMass, java.util.Map<java.lang.String,double[]> deltaAnnotations)
           
 void countMergeEffects(Tree tree, java.util.Map<java.lang.String,double[]> totalStateMass, java.util.Map<java.lang.String,double[]> deltaAnnotations, java.util.IdentityHashMap<Tree,double[]> probIn, java.util.IdentityHashMap<Tree,double[]> probOut)
           
 void countOriginalStates()
          Count all the internal labels in all the trees, and set their initial state counts to 1.
 void extract(java.util.Collection<Tree> treeList)
           
 void extract(java.util.Collection<Tree> trees1, double weight1, java.util.Collection<Tree> trees2, double weight2)
          First, we do a few setup steps.
 int getStateSplitCount(java.lang.String label)
           
 int getStateSplitCount(Tree tree)
           
 void mergeStates()
           
 void mergeTransitions(Tree parent, java.util.IdentityHashMap<Tree,double[][]> oldUnaryTransitions, java.util.IdentityHashMap<Tree,double[][][]> oldBinaryTransitions, java.util.IdentityHashMap<Tree,double[][]> newUnaryTransitions, java.util.IdentityHashMap<Tree,double[][][]> newBinaryTransitions, double[] stateWeights, java.util.Map<java.lang.String,int[]> mergeCorrespondence)
          Given a tree and the original set of transition probabilities from one state to the next in the tree, along with a list of the weights in the tree and a count of the mass in each substate at the current node, this method merges the probabilities as necessary.
 void outputBetas()
           
 void outputTransitions(Tree tree, java.util.IdentityHashMap<Tree,double[][]> unaryTransitions, java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)
           
 void outputTransitions(Tree tree, int depth, java.util.IdentityHashMap<Tree,double[][]> unaryTransitions, java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)
           
 boolean recalculateBetas(boolean splitStates)
          Recalculates the betas for all known transitions.
 void recalculateMergedBetas(java.util.Map<java.lang.String,int[]> mergeCorrespondence)
           
 void recalculateTemporaryBetas(boolean splitStates, java.util.Map<java.lang.String,double[]> totalStateMass, TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas, ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)
          Creates temporary beta data structures and fills them in by iterating over the trees.
 void recalculateTemporaryBetas(Tree tree, boolean splitStates, java.util.Map<java.lang.String,double[]> totalStateMass, TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas, ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)
           
 int recalculateTemporaryBetas(Tree tree, double[] stateWeights, int position, java.util.IdentityHashMap<Tree,double[][]> unaryTransitions, java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions, java.util.Map<java.lang.String,double[]> totalStateMass, TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas, ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)
           
 int recountInside(Tree tree, boolean splitStates, int loc, java.util.IdentityHashMap<Tree,double[]> probIn)
           
 void recountOutside(Tree tree, java.util.IdentityHashMap<Tree,double[]> probIn, java.util.IdentityHashMap<Tree,double[]> probOut)
           
 void recountOutside(Tree child, Tree parent, java.util.IdentityHashMap<Tree,double[]> probIn, java.util.IdentityHashMap<Tree,double[]> probOut)
           
 void recountOutside(Tree left, Tree right, Tree parent, java.util.IdentityHashMap<Tree,double[]> probIn, java.util.IdentityHashMap<Tree,double[]> probOut)
           
 void recountTree(Tree tree, boolean splitStates, java.util.IdentityHashMap<Tree,double[][]> unaryTransitions, java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)
           
 void recountTree(Tree tree, boolean splitStates, java.util.IdentityHashMap<Tree,double[]> probIn, java.util.IdentityHashMap<Tree,double[]> probOut, java.util.IdentityHashMap<Tree,double[][]> unaryTransitions, java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)
           
 void recountWeights(Tree tree, java.util.IdentityHashMap<Tree,double[]> probIn, java.util.IdentityHashMap<Tree,double[]> probOut, java.util.IdentityHashMap<Tree,double[][]> unaryTransitions, java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)
           
 void recurseOutside(Tree tree, java.util.IdentityHashMap<Tree,double[]> probIn, java.util.IdentityHashMap<Tree,double[]> probOut)
           
 void rescaleTemporaryBetas(TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas, ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)
           
 void saveTrees(java.util.Collection<Tree> trees1, double weight1, java.util.Collection<Tree> trees2, double weight2)
           
 void splitBetas()
          Before each iteration of splitting states, we have tables of betas which correspond to the transitions between different substates.
 java.lang.String state(java.lang.String tag, int i)
           
 boolean testConvergence(TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas, ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)
           
 boolean useNewBetas(boolean testConverged, TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas, ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)
           
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

SplittingGrammarExtractor

public SplittingGrammarExtractor(Options op)
Method Detail

outputTransitions

public void outputTransitions(Tree tree,
                              java.util.IdentityHashMap<Tree,double[][]> unaryTransitions,
                              java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)

outputTransitions

public void outputTransitions(Tree tree,
                              int depth,
                              java.util.IdentityHashMap<Tree,double[][]> unaryTransitions,
                              java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)

outputBetas

public void outputBetas()

state

public java.lang.String state(java.lang.String tag,
                              int i)

getStateSplitCount

public int getStateSplitCount(Tree tree)

getStateSplitCount

public int getStateSplitCount(java.lang.String label)

countOriginalStates

public void countOriginalStates()
Count all the internal labels in all the trees, and set their initial state counts to 1.
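
A minimal sketch of this counting step, under stated assumptions: Node is a hypothetical stand-in for the Tree class, "internal" is taken to mean any node with at least one child, and the map of split counts is an assumed data structure rather than the field the class actually uses.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Illustrative sketch only; names here are assumptions, not the real API. */
final class CountStatesSketch {

  /** Hypothetical stand-in for a parse tree node. */
  static final class Node {
    final String label;
    final List<Node> children = new ArrayList<>();

    Node(String label, Node... kids) {
      this.label = label;
      for (Node kid : kids) {
        children.add(kid);
      }
    }

    boolean isLeaf() {
      return children.isEmpty();
    }
  }

  /** Records every non-leaf label with an initial split count of 1. */
  static void countStates(Node node, Map<String, Integer> splitCounts) {
    if (node.isLeaf()) {
      return;  // words are not grammar states
    }
    splitCounts.put(node.label, 1);
    for (Node child : node.children) {
      countStates(child, splitCounts);
    }
  }

  public static void main(String[] args) {
    Node tree = new Node("S",
        new Node("NP", new Node("NN", new Node("dog"))),
        new Node("VP", new Node("VB", new Node("barks"))));
    Map<String, Integer> splitCounts = new TreeMap<>();
    countStates(tree, splitCounts);
    System.out.println(splitCounts);
  }
}
```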


splitBetas

public void splitBetas()
Before each iteration of splitting states, we have tables of betas which correspond to the transitions between different substates. When we resplit the states, we duplicate parent states and then split their transitions 50/50 with some random variation between child states.
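
The splitting step can be sketched for a single unary beta table as follows. This is an assumption-laden illustration, not the class's actual code: the table layout beta[parentSubstate][childSubstate], the maxNoise parameter, and the way the random variation is drawn are all hypothetical. Each parent substate is duplicated, and each transition's mass is divided between the two new child substates in a roughly 50/50 ratio.

```java
import java.util.Random;

/** Illustrative sketch only; layout and noise model are assumptions. */
final class SplitSketch {

  /**
   * Doubles both dimensions of a beta table. Every old parent row is copied to
   * two new rows, and every old entry is divided between two new child columns
   * using a fraction drawn from [0.5 - maxNoise, 0.5 + maxNoise).
   */
  static double[][] split(double[][] beta, Random random, double maxNoise) {
    int parents = beta.length;
    int children = beta[0].length;
    double[][] result = new double[parents * 2][children * 2];
    for (int p = 0; p < parents; p++) {
      for (int half = 0; half < 2; half++) {
        for (int c = 0; c < children; c++) {
          double fraction = 0.5 + (random.nextDouble() * 2.0 - 1.0) * maxNoise;
          result[2 * p + half][2 * c] = beta[p][c] * fraction;
          result[2 * p + half][2 * c + 1] = beta[p][c] * (1.0 - fraction);
        }
      }
    }
    return result;
  }

  public static void main(String[] args) {
    double[][] beta = {{0.3, 0.7}};
    double[][] splitBeta = split(beta, new Random(42L), 0.01);
    for (double[] row : splitBeta) {
      double sum = 0.0;
      for (double v : row) {
        sum += v;
      }
      System.out.println(sum);  // each new row keeps the old row's total mass
    }
  }
}
```

The random variation matters because two exactly identical substates would receive identical updates in later iterations and never diverge.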


recalculateBetas

public boolean recalculateBetas(boolean splitStates)
Recalculates the betas for all known transitions. The current betas are used to produce probabilities, which then are used to compute new betas. If splitStates is true, then the probabilities produced are as if the states were split again from the last time betas were calculated.
The return value is whether or not the betas have mostly converged from the last time this method was called. Obviously if splitStates was true, the betas will be entirely different, so this is false. Otherwise, the new betas are compared against the old values, and convergence means they differ by less than EPSILON.
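
A minimal sketch of such a convergence check, assuming that "differ by less than EPSILON" means every entry's absolute change is below the threshold. The epsilon value is passed as a parameter here because the actual value of EPSILON is not given in this documentation, and the flat double[][] table is a simplification of the real beta maps.

```java
/** Illustrative sketch only; the real test operates on the beta maps. */
final class ConvergenceSketch {

  /** Returns true if no entry moved by epsilon or more between iterations. */
  static boolean converged(double[][] oldBetas, double[][] newBetas, double epsilon) {
    for (int i = 0; i < oldBetas.length; i++) {
      for (int j = 0; j < oldBetas[i].length; j++) {
        if (Math.abs(oldBetas[i][j] - newBetas[i][j]) >= epsilon) {
          return false;
        }
      }
    }
    return true;
  }

  public static void main(String[] args) {
    double[][] before = {{0.5, 0.5}};
    double[][] nearby = {{0.5000001, 0.4999999}};
    double[][] moved = {{0.6, 0.4}};
    System.out.println(converged(before, nearby, 1e-4));  // prints true
    System.out.println(converged(before, moved, 1e-4));   // prints false
  }
}
```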


useNewBetas

public boolean useNewBetas(boolean testConverged,
                           TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas,
                           ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)

recalculateTemporaryBetas

public void recalculateTemporaryBetas(boolean splitStates,
                                      java.util.Map<java.lang.String,double[]> totalStateMass,
                                      TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas,
                                      ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)
Creates temporary beta data structures and fills them in by iterating over the trees.


testConvergence

public boolean testConvergence(TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas,
                               ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)

recalculateTemporaryBetas

public void recalculateTemporaryBetas(Tree tree,
                                      boolean splitStates,
                                      java.util.Map<java.lang.String,double[]> totalStateMass,
                                      TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas,
                                      ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)

recalculateTemporaryBetas

public int recalculateTemporaryBetas(Tree tree,
                                     double[] stateWeights,
                                     int position,
                                     java.util.IdentityHashMap<Tree,double[][]> unaryTransitions,
                                     java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions,
                                     java.util.Map<java.lang.String,double[]> totalStateMass,
                                     TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas,
                                     ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)

rescaleTemporaryBetas

public void rescaleTemporaryBetas(TwoDimensionalMap<java.lang.String,java.lang.String,double[][]> tempUnaryBetas,
                                  ThreeDimensionalMap<java.lang.String,java.lang.String,java.lang.String,double[][][]> tempBinaryBetas)

recountTree

public void recountTree(Tree tree,
                        boolean splitStates,
                        java.util.IdentityHashMap<Tree,double[][]> unaryTransitions,
                        java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)

recountTree

public void recountTree(Tree tree,
                        boolean splitStates,
                        java.util.IdentityHashMap<Tree,double[]> probIn,
                        java.util.IdentityHashMap<Tree,double[]> probOut,
                        java.util.IdentityHashMap<Tree,double[][]> unaryTransitions,
                        java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)

recountWeights

public void recountWeights(Tree tree,
                           java.util.IdentityHashMap<Tree,double[]> probIn,
                           java.util.IdentityHashMap<Tree,double[]> probOut,
                           java.util.IdentityHashMap<Tree,double[][]> unaryTransitions,
                           java.util.IdentityHashMap<Tree,double[][][]> binaryTransitions)

recountOutside

public void recountOutside(Tree tree,
                           java.util.IdentityHashMap<Tree,double[]> probIn,
                           java.util.IdentityHashMap<Tree,double[]> probOut)

recurseOutside

public void recurseOutside(Tree tree,
                           java.util.IdentityHashMap<Tree,double[]> probIn,
                           java.util.IdentityHashMap<Tree,double[]> probOut)

recountOutside

public void recountOutside(Tree child,
                           Tree parent,
                           java.util.IdentityHashMap<Tree,double[]> probIn,
                           java.util.IdentityHashMap<Tree,double[]> probOut)

recountOutside

public void recountOutside(Tree left,
                           Tree right,
                           Tree parent,
                           java.util.IdentityHashMap<Tree,double[]> probIn,
                           java.util.IdentityHashMap<Tree,double[]> probOut)

recountInside

public int recountInside(Tree tree,
                         boolean splitStates,
                         int loc,
                         java.util.IdentityHashMap<Tree,double[]> probIn)

mergeStates

public void mergeStates()

recalculateMergedBetas

public void recalculateMergedBetas(java.util.Map<java.lang.String,int[]> mergeCorrespondence)

mergeTransitions

public void mergeTransitions(Tree parent,
                             java.util.IdentityHashMap<Tree,double[][]> oldUnaryTransitions,
                             java.util.IdentityHashMap<Tree,double[][][]> oldBinaryTransitions,
                             java.util.IdentityHashMap<Tree,double[][]> newUnaryTransitions,
                             java.util.IdentityHashMap<Tree,double[][][]> newBinaryTransitions,
                             double[] stateWeights,
                             java.util.Map<java.lang.String,int[]> mergeCorrespondence)
Given a tree and the original set of transition probabilities from one state to the next in the tree, along with a list of the weights in the tree and a count of the mass in each substate at the current node, this method merges the probabilities as necessary. The results go into newUnaryTransitions and newBinaryTransitions.


countMergeEffects

public void countMergeEffects(Tree tree,
                              java.util.Map<java.lang.String,double[]> totalStateMass,
                              java.util.Map<java.lang.String,double[]> deltaAnnotations)

countMergeEffects

public void countMergeEffects(Tree tree,
                              java.util.Map<java.lang.String,double[]> totalStateMass,
                              java.util.Map<java.lang.String,double[]> deltaAnnotations,
                              java.util.IdentityHashMap<Tree,double[]> probIn,
                              java.util.IdentityHashMap<Tree,double[]> probOut)

buildStateIndex

public void buildStateIndex()

buildGrammars

public void buildGrammars()

saveTrees

public void saveTrees(java.util.Collection<Tree> trees1,
                      double weight1,
                      java.util.Collection<Tree> trees2,
                      double weight2)

extract

public void extract(java.util.Collection<Tree> treeList)

extract

public void extract(java.util.Collection<Tree> trees1,
                    double weight1,
                    java.util.Collection<Tree> trees2,
                    double weight2)
First, we do a few setup steps. We read in all the trees, which is necessary because we continually reprocess them and use the object pointers as hash keys rather than hashing the trees themselves. We then count the initial states in the treebank.
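
The effect of keying on object pointers can be seen in the sketch below. It uses two equal-valued lists as hypothetical stand-ins for two structurally identical subtrees; the class and variable names are assumptions made for this example. A HashMap treats the two keys as one entry, while an IdentityHashMap keeps a separate entry per object, which is what per-node annotations need.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

/** Illustrative sketch only; lists stand in for tree nodes. */
final class IdentityKeySketch {

  /** Returns {size of value-keyed map, size of identity-keyed map}. */
  static int[] sizes() {
    // Two distinct objects that compare equal by value.
    List<String> first = new ArrayList<>();
    first.add("NP");
    List<String> second = new ArrayList<>();
    second.add("NP");

    Map<List<String>, double[]> byValue = new HashMap<>();
    Map<List<String>, double[]> byIdentity = new IdentityHashMap<>();

    byValue.put(first, new double[] {1.0});
    byValue.put(second, new double[] {2.0});     // overwrites the first entry
    byIdentity.put(first, new double[] {1.0});
    byIdentity.put(second, new double[] {2.0});  // kept as a separate entry

    return new int[] {byValue.size(), byIdentity.size()};
  }

  public static void main(String[] args) {
    int[] sizes = sizes();
    System.out.println(sizes[0]);  // prints 1
    System.out.println(sizes[1]);  // prints 2
  }
}
```

Identity keys also avoid recomputing a structural hash of a whole subtree on every lookup, but they only work if the same tree objects are kept in memory across iterations.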
Having done that, we then assign initial probabilities to the trees. At first, each state has 1.0 of the probability mass for each Ax-ByCz and Ax-By transition. We then split the number of states and the probabilities on each tree.
We then repeatedly recalculate the betas and reannotate the weights until we converge, which is defined as no beta moving by more than epsilon.
java -mx4g edu.stanford.nlp.parser.lexparser.LexicalizedParser -PCFG -saveToSerializedFile englishSplit.ser.gz -saveToTextFile englishSplit.txt -maxLength 40 -train ../data/wsj/wsjtwentytrees.mrg -testTreebank ../data/wsj/wsjtwentytrees.mrg -evals "factDA,tsv" -uwm 0 -hMarkov 0 -vMarkov 0 -simpleBinarizedLabels -noRebinarization -predictSplits -splitTrainingThreads 1 -splitCount 1 -splitRecombineRate 0.5
You may also need:
-smoothTagsThresh 0
java -mx8g edu.stanford.nlp.parser.lexparser.LexicalizedParser -evals "factDA,tsv" -PCFG -vMarkov 0 -hMarkov 0 -uwm 0 -saveToSerializedFile wsjS1.ser.gz -maxLength 40 -train /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-2199 -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 2200-2219 -compactGrammar 0 -simpleBinarizedLabels -predictSplits -smoothTagsThresh 0 -splitCount 1 -noRebinarization



Stanford NLP Group