#include "libssa.h"
#include <stdio.h>
#include <stdarg.h>
#include <string.h>
#include <stdlib.h>  
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include "mt.i"

// Format fmt/va into a freshly malloc()ed, NUL-terminated string that the
// caller owns and must free().  The va_list is consumed; the caller still
// calls va_end on it.  Aborts if vsnprintf reports an encoding error or the
// allocation fails, matching the abort-on-error policy used in this file.
static char* _allocSprintf(char* fmt,va_list va) {
  // vsnprintf consumes the va_list it is given, so measure the output with a
  // copy and keep the original for the one real formatting pass.  Passing the
  // same va_list to vsnprintf repeatedly (as a grow-and-retry loop does) is
  // undefined behaviour.
  va_list measure;
  va_copy(measure,va);
  int len=vsnprintf(NULL,0,fmt,measure);
  va_end(measure);
  if(len<0) { fprintf(stderr,"_allocSprintf: bad format\n"); abort(); }

  char* s=(char*)malloc((size_t)len+1);
  if(!s) { fprintf(stderr,"_allocSprintf: out of memory\n"); abort(); }
  vsnprintf(s,(size_t)len+1,fmt,va);
  return s;
}

// Create a heap-allocated Participant whose name is built printf-style from
// fmt and the arguments that follow it.  The caller owns both the Participant
// and its malloc()ed name string.
Participant* participant(char* fmt,...) {
  Participant* created=(Participant*)malloc(sizeof(Participant));
  va_list args;
  va_start(args,fmt);
  created->name=_allocSprintf(fmt,args);
  va_end(args);
  return created;
}

// Build a heap-allocated Term that pairs a coefficient with a participant.
// The participant is stored by pointer, not copied.
Term* term(u32 coefficient,Participant* participant) {
  Term* made=(Term*)malloc(sizeof(Term));
  made->participant=participant;
  made->coefficient=coefficient;
  return made;
}

// Build a NULL-terminated, heap-allocated array of Term pointers from a
// NULL-terminated argument list, e.g. termList(a, b, c, NULL).  The Terms are
// stored by pointer, not copied.  termList(NULL) yields an empty list (only
// the terminator) without reading any further arguments.  Aborts if the
// allocation fails.
Term** termList(Term* t1,...) {
  if(t1==NULL) {
    // Empty list: no variadic arguments were passed, so reading the va_list
    // would be undefined behaviour.  Don't touch it.
    Term** r=(Term**)malloc(sizeof(Term*));
    if(!r) abort();
    r[0]=NULL;
    return r;
  }

  va_list va,va2;
  va_start(va,t1);

  // Count the arguments: t1 plus every non-NULL one after it
  int n=1;
  va_copy(va2,va);
  while(va_arg(va2,Term*)) n++;
  va_end(va2);

  // Room for the n Terms plus the NULL terminator, then populate it
  Term** r=(Term**)malloc((n+1)*sizeof(Term*));
  if(!r) abort();
  r[0]=t1;
  for(int i=1;(r[i]=va_arg(va,Term*));i++);
  va_end(va);

  return r;
}

// Bundle a propensity with reactant and product Term lists into a
// heap-allocated SSAReaction.  The lists are stored by pointer, not copied.
SSAReaction* reaction(double propensity,Term** reactants,Term** products) {
  SSAReaction* rx=(SSAReaction*)malloc(sizeof(SSAReaction));
  rx->reactants=reactants;
  rx->products=products;
  rx->propensity=propensity;
  return rx;
}
 
ReactionList* reactionList(void) {
  ReactionList* l=(ReactionList*)malloc(sizeof(ReactionList));
  memset(l,0,sizeof(ReactionList));
  // Always keep a NULL on the end of list so that it's ready to pass to libssa
  l->size=1; 
  l->list=(SSAReaction**)malloc(sizeof(SSAReaction*));
  l->list[0]=NULL;
  return l;
}
  
// Append r to l, growing the backing array by 10 slots when only the
// terminator slot is left.  The array always keeps a NULL after the last
// reaction, so l->list stays a valid NULL-terminated array.  Aborts if the
// array cannot be grown; l is left untouched in that case.
void reactionListAdd(ReactionList* l,SSAReaction* r) {
  if(l->num==l->size-1) {              // Always keep a NULL on the end of list
    // Don't overwrite l->list until realloc has succeeded; on failure the old
    // block would otherwise be leaked and then dereferenced as NULL.
    SSAReaction** grown=(SSAReaction**)
      realloc(l->list,sizeof(SSAReaction*)*(l->size+10));
    if(!grown) { fprintf(stderr,"reactionListAdd: out of memory\n"); abort(); }
    l->list=grown;
    l->size+=10;
    // Clear everything from the current end onward (including the old
    // terminator slot) so the list stays NULL-terminated after the append.
    memset(l->list+l->num,0,sizeof(SSAReaction*)*(l->size-l->num));
  }
  l->list[l->num++]=r;
}

//-- DirectLinear

// One entry of the participant table built by DirectLinear: a participant
// (identified by its address) and a population count.
typedef struct {
  Participant* ptr;   // the participant; findPd matches entries by pointer equality
  u64 n;              // population; DirectLinear fills it from the initial conditions
} ParticipantData;
    
// Growable array of ParticipantData.  findPd() searches it linearly and
// appends to it; an entry's position is the participant's index.
typedef struct {
  int size;             // allocated capacity of d, in entries
  int next;             // number of entries in use (also the next free index)
  ParticipantData* d;   // entries [0, next); realloc()ed by findPd when full
} ParticipantDataSet;
  
// Concrete representation behind the Trajectory* that DirectLinear returns
// (it casts a _Trajectory* to Trajectory*; the trajectory* functions cast back).
typedef struct {
  Timeline tl;               // copy of the timeline the trajectory was sampled on
  ParticipantDataSet* pd;    // participant table; participant j is pd->d[j]
  u64* data;                 // row-major samples: data[t*pd->next + j] is participant j at time point t
} _Trajectory;

// Return the index of participant p in pds, appending a new entry (with a
// zero population) if p is not present yet.  The lookup is a linear search,
// so each call costs O(pds->next).  The entry array grows in chunks of 1024
// entries; the process aborts if growing it fails.
int findPd(ParticipantDataSet* pds,Participant* p) {
  for(int i=0;i<pds->next;i++)
    if(pds->d[i].ptr==p) return i;

  // Not there already: add it, growing the array first if it's full.
  if(pds->next==pds->size) {
    // Keep pds->d valid until realloc succeeds.
    ParticipantData* grown=(ParticipantData*)realloc(pds->d,
      sizeof(ParticipantData)*(pds->size+1024));
    if(!grown) { fprintf(stderr,"findPd: out of memory\n"); abort(); }
    pds->d=grown;
    pds->size+=1024;
  }
  int idx=pds->next++;
  pds->d[idx].ptr=p;
  pds->d[idx].n=0;      // don't leave the population indeterminate
  return idx;
}

// NOTE(review): gccHack is not referenced anywhere in this file.  It looks
// like a leftover from the GCC workaround described in DirectLinear and is
// probably removable -- confirm nothing depends on it before deleting.
static volatile int gccHack=0;

Trajectory* DirectLinear(Problem* p,Timeline* tl) {
#if 0
  Set s;
  setInit(&s);
  f(&s,p);
  _compact(&s);
  printf("There are %d reactants\n",s.num);
  for(int i=0;i<s.num;i++) printf("  %s (%p)\n",
    ((Participant*)s.vals[i])->name,s.vals[i]);
#endif

  // The stupid way: linear search a map from participant ptr to anciliary
  // data.  Realloc when full.  Blech.
  ParticipantDataSet* pd=(ParticipantDataSet*)
    malloc(sizeof(ParticipantDataSet));
  memset(pd,0,sizeof(ParticipantDataSet));

  // Need to assign indices to each of the reactants, so need to traverse
  // them as some kind of reachability thing from the Problem.
  int nr=0;
  for(SSAReaction** rp=p->reactions;*rp;rp++) {
    nr++;
    SSAReaction* r=*rp;
//    printf("SSAReaction:\n");

//    printf("  reactants:");
    for(Term** tp=r->reactants;*tp;tp++) {
      Term* t=*tp;
      int i=findPd(pd,t->participant); i=i;
//      printf(" %lu*%s(%d)",t->coefficient,t->participant->name,i);
    }
//    printf("\n");

//    printf("  products:");
    for(Term** tp=r->products;*tp;tp++) {
      Term* t=*tp;
      int i=findPd(pd,t->participant); i=i;
//      printf(" %lu*%s(%d)",t->coefficient,t->participant->name,i);
    }
//    printf("\n");

  }

//  printf("Initial Conditions:\n");
  for(InitialCondition** icp=p->initialConditions;*icp;icp++) {
    InitialCondition* ic=*icp;
    int oldNext=pd->next;
    int i=findPd(pd,ic->participant); i=i;
    if(pd->next!=oldNext) {
      fprintf(stderr,"Warning: Initial condition references participant %p"
                     " (%s)\n         that is not referenced by Problem.\n",
                     ic->participant,ic->participant->name);
    }
//    printf("  %s(%d) = %llu\n",ic->participant->name,i,ic->population);
  }

  printf("There are %d reactants\n",pd->next);
  printf("There are %d reactions\n",nr);

  int ind=open("in",O_WRONLY|O_CREAT|O_TRUNC,0666);
  if(ind<0) abort();
  FILE* in=fdopen(ind,"w");
  if(!in) abort();
  fprintf(in,"%d\n",pd->next);                        // Number of reactants
  fprintf(in,"%d\n",nr);                             // Number of reactions

  // Shuffle together the initial condition populations for each reactant
  for(int i=0;i<pd->next;i++) pd->d[i].n=0;
  for(InitialCondition** icp=p->initialConditions;*icp;icp++) {
    InitialCondition* ic=*icp;
    int i=findPd(pd,ic->participant);
    pd->d[i].n=ic->population;
  }
  for(int i=0;i<pd->next;i++)
    fprintf(in,FORMAT_U64"%s",pd->d[i].n,(i==pd->next-1)?"":" ");
  fprintf(in,"\n");

  // Now write out the equation vector.  It is two groups per reaction, where
  // a group is a number N followed by N pairs.  Each pair is a participant
  // index followed by a coefficient.
  for(SSAReaction** rp=p->reactions;*rp;rp++) {
    SSAReaction* r=*rp;

    u32 c=0;
    for(Term** tp=r->reactants;*tp;tp++) c++;  // Count them
    fprintf(in,FORMAT_U32" ",c);
    for(Term** tp=r->reactants;*tp;tp++) {     // Emit them
      Term* t=*tp;
      int i=findPd(pd,t->participant);
      fprintf(in,"%d "FORMAT_U32" ",i,t->coefficient);
    }

    c=0;
    for(Term** tp=r->products;*tp;tp++) c++;  // Count them
    fprintf(in,FORMAT_U32" ",c);
    for(Term** tp=r->products;*tp;tp++) {     // Emit them
      Term* t=*tp;
      int i=findPd(pd,t->participant);
      fprintf(in,"%d "FORMAT_U32" ",i,t->coefficient);
    }
  }
  fprintf(in,"\n");

  // Now emit the propensities of the reactions
  for(SSAReaction** rp=p->reactions;*rp;rp++) {
    SSAReaction* r=*rp;
    fprintf(in,"%.20le ",r->propensity);
  }
  fprintf(in,"\n");

  // Record all the species
  fprintf(in,"%d\n",pd->next);                // How many
  for(int i=0;i<pd->next;i++) fprintf(in,"%d ",i);
  fprintf(in,"\n");

  // Record all the reactions
  fprintf(in,"%d\n",nr);
  for(int i=0;i<nr;i++) fprintf(in,"%d ",i);
  fprintf(in,"\n");

  fprintf(in,"0\n");      // No maximum # allowed steps

  fprintf(in,"0\n");      // No solver params

  // GCC is so broken it's ridiculous.  It gets this wrong by one at -O3:
  //   int c=0;
  //   for(double t=tl->t0;t<=tl->end;t+=tl->incr) if(t>=tl->start) c++;
  // with tl=(0,0,1,0.1) it says c=10 not c=11 (even 'tho a second time thru
  // the loop printing things out yields 11 outputs).

  int c=(int)((tl->end-tl->start)/tl->incr)+1;
  fprintf(in,"%d\n",c);
  // Don't add t0; the current solvers have epoch always ==0
  for(int i=0;i<c;i++) fprintf(in,"%lf ",(i*(tl->end-tl->start)/(c-1)));
  fprintf(in,"\n");

  // Emit the MT initial state vector
  for(int i=0;i<sizeof(mtInit)/sizeof(u32);i++) fprintf(in,FORMAT_U32" ",mtInit[i]);
  fprintf(in,"475\n");

  // Show me 1 trajectory
  fprintf(in,"1\n");

  fclose(in);

  system("/home/funalab/share/work/git/libssa/solvers/Cain-0.12/solvers/DirectLinearSearch.exe < in > out");
  unlink("in");

  FILE* out=fopen("out","r");
  if(!out) abort();

  char s[256];

  fscanf(out,"%*d ");                             // num reactants
  fscanf(out,"%*d ");                             // num reactions
  fscanf(out,"%*d ");                             // num reactants to record
  for(int i=0;i<pd->next;i++) fscanf(out,"%*d ");  // list of reactants recorded
  fscanf(out,"%*d ");                             // num reactions to record
  for(int i=0;i<nr;i++) fscanf(out,"%*d ");       // list of reactions recorded
  fscanf(out,"%s ",s);                            // "TrajectoryFrames"
  if(strcmp(s,"TrajectoryFrames")) {
    fprintf(stderr,"Error: expected 'TrajectoryFrames' got '%s'\n",s);
    abort();
  }
  fscanf(out,"%*d ");                             // num time points
  for(int i=0;i<c;i++) fscanf(out,"%*f ");        // t points
  fscanf(out,"%*d ");                             // num trajectories (1)
  for(int i=0;i<sizeof(mtInit)/sizeof(u32);i++) fscanf(out,"%*d "); // MT state
  fscanf(out,"%*d ");                             // mt offset
  fscanf(out,"%200s ",s);
  if(strcmp(s,"success")) {
    fprintf(stderr,"Error: expected 'success' got '%s'\n",s);
    abort();
  }

  u64* traj=(u64*)malloc(c*pd->next*sizeof(u64));

  for(int j=0;j<c;j++)
    for(int i=0;i<pd->next;i++) fscanf(out,FORMAT_U64" ",&traj[j*pd->next+i]);
  fclose(out);

  unlink("out");

  _Trajectory* r=(_Trajectory*)malloc(sizeof(_Trajectory));
  memcpy(&r->tl,tl,sizeof(Timeline));
  r->data=traj;
  r->pd=pd;
  return (Trajectory*)r;
}

// Write traj to f as whitespace-separated text, one line per time point.
// Each line holds the time (t0 plus the point's offset) followed by the
// population of every participant in participant-index order.
// trajectoryPlot uses this to produce gnuplot's data file gp.d.
void trajectoryWritePts(FILE* f,Trajectory* traj) {
  _Trajectory* tr=(_Trajectory*)traj;
  Timeline* tl=&tr->tl;
  int c=(int)((tl->end-tl->start)/tl->incr)+1;
  int np=tr->pd->next;
  for(int i=0;i<c;i++) {
    // With a single time point the spacing formula would divide 0 by 0;
    // that point sits at offset 0.
    double offset=(c>1)?(i*(tl->end-tl->start)/(c-1)):0.0;
    fprintf(f,"%lf ",tl->t0+offset);
    for(int j=0;j<np;j++)
      fprintf(f,FORMAT_U64" ",tr->data[i*np+j]);
    fprintf(f,"\n");
  }
}

// Write an executable gnuplot script "gp.r" that plots the columns of gp.d
// belonging to the participants in the NULL-terminated array ps.  The x range
// is [xmin,xmax], the axis labels are xl and yl, and series cycle through 8
// fixed line colours.  gp.d itself is not written here; presumably the
// caller produces it, e.g. with trajectoryWritePts.  Participants that are
// not part of traj are skipped with a warning.  If gp.r cannot be created,
// an error is printed and nothing is written.
void trajectoryWriteGnuplotScript(Trajectory* traj,Participant** ps,
                                  double xmin,double xmax,char* xl,char* yl) {
  _Trajectory* tr=(_Trajectory*)traj;
  // Plain "w" rather than "wc": the 'c' mode flag is a glibc extension and
  // ISO C leaves unlisted mode strings undefined.
  FILE* out=fopen("gp.r","w");
  if(!out) { perror("gp.r"); return; }

  static const char* const colors[]={ "ff0000","00ff00","0000ff","ff00ff",
                                      "ff88ff","88ffff","ffff88","8888ff" };
  const int nColors=(int)(sizeof colors/sizeof colors[0]);

  fprintf(out,"#!/usr/local/bin/gnuplot -persist\n");
  for(int i=0;i<nColors;i++)
    fprintf(out,"set style line %d lt rgbcolor \"#%s\"\n",i+1,colors[i]);
  fprintf(out,"set xlabel \"%s\"\n",xl);
  fprintf(out,"set ylabel \"%s\"\n",yl);
  fprintf(out,"plot [%lf:%lf]",xmin,xmax);

  int color=0,emitted=0;
  for(Participant** p=ps;*p;p++) {
    // Look the participant up without adding it.  findPd would append an
    // unknown participant to tr->pd, changing the row stride that the other
    // trajectory functions use to index tr->data.
    int i=-1;
    for(int k=0;k<tr->pd->next;k++)
      if(tr->pd->d[k].ptr==*p) { i=k; break; }
    if(i<0) {
      fprintf(stderr,"Warning: participant %s is not part of the trajectory;"
                     " not plotted\n",(*p)->name);
      continue;
    }
    // Column 1 of gp.d is the time; participant index i is column i+2.
    fprintf(out,"%s\\\n  'gp.d' using ($1):($%d) with lines linestyle %d title '%s'",
      emitted?",":"",i+2,color+1,tr->pd->d[i].ptr->name);
    emitted++;
    color=(color+1)%nColors;
  }
  fprintf(out,"\n");
  if(fclose(out)!=0) { perror("gp.r"); return; }

  // chmod(2) directly instead of spawning a shell to run "chmod 755".
  if(chmod("gp.r",0755)!=0) perror("gp.r");
}

// Plot the listed participants of traj with gnuplot.  participants is a
// NULL-terminated array.  The samples are written to gp.d (via
// trajectoryWritePts) and a plot script to gp.r in the current directory.
// The script is run in the foreground and both files are deleted afterwards.
// Participants that are not part of the trajectory are skipped with a
// warning.  If either file cannot be written, nothing is plotted.
void trajectoryPlot(Trajectory* traj,Participant** participants) {
  _Trajectory* tr=(_Trajectory*)traj;

  // Plain "w" rather than "wc": the 'c' mode flag is a glibc extension and
  // ISO C leaves unlisted mode strings undefined.
  FILE* out=fopen("gp.d","w");
  if(!out) { perror("gp.d"); return; }
  trajectoryWritePts(out,traj);
  if(fclose(out)!=0) { perror("gp.d"); unlink("gp.d"); return; }

  out=fopen("gp.r","w");
  if(!out) { perror("gp.r"); unlink("gp.d"); return; }
  fprintf(out,"#!/usr/local/bin/gnuplot -persist\n");
  fprintf(out,"plot ");
  int emitted=0;
  for(Participant** p=participants;*p;p++) {
    // Look the participant up without adding it.  findPd would append an
    // unknown participant to tr->pd, and tr->data has no column for it.
    int i=-1;
    for(int k=0;k<tr->pd->next;k++)
      if(tr->pd->d[k].ptr==*p) { i=k; break; }
    if(i<0) {
      fprintf(stderr,"Warning: participant %s is not part of the trajectory;"
                     " not plotted\n",(*p)->name);
      continue;
    }
    // Column 1 of gp.d is the time; participant index i is column i+2.
    fprintf(out,"%s\\\n  'gp.d' using ($1):($%d) with lines title '%s'",
      emitted?",":"",i+2,tr->pd->d[i].ptr->name);
    emitted++;
  }
  fprintf(out,"\n");
  if(fclose(out)!=0) {
    perror("gp.r");
    unlink("gp.r");
    unlink("gp.d");
    return;
  }

  if(chmod("gp.r",0755)!=0) perror("gp.r");
  // Only the viewer is backgrounded; gp.r runs to completion before the
  // files are deleted below.  NOTE(review): nothing here writes gp.sps (the
  // script sets no output file), so the gv step looks stale -- confirm.
  system("./gp.r ; gv gp.sps &");
  unlink("gp.r");
  unlink("gp.d");
}

// Convert the populations at the final time point of traj into a
// NULL-terminated array of InitialCondition pointers, one per participant in
// participant-index order.  Presumably intended as the starting state of a
// follow-on simulation.  Two blocks are malloc()ed: the returned pointer array,
// and one InitialCondition array that every entry points into (its base is
// element 0 when the list is non-empty).
InitialCondition** trajectoryLastAsIC(Trajectory* traj) {
  const _Trajectory* t=(const _Trajectory*)traj;
  int nPart=t->pd->next;
  int nPoints=(int)((t->tl.end-t->tl.start)/t->tl.incr)+1;
  const u64* lastRow=t->data+(nPoints-1)*nPart;   // row of the final time point

  InitialCondition* conds=
    (InitialCondition*)malloc(nPart*sizeof(InitialCondition));
  InitialCondition** list=
    (InitialCondition**)malloc((nPart+1)*sizeof(InitialCondition*));
  int k=0;
  while(k<nPart) {
    InitialCondition* ic=&conds[k];
    ic->participant=t->pd->d[k].ptr;
    ic->population=lastRow[k];
    list[k]=ic;
    k++;
  }
  list[nPart]=NULL;
  return list;
}
