/*  Copyright (c) 2005 Romain BONDUE
    This file is part of RutilT.

    RutilT is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    RutilT is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with RutilT; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
/** \file Helper.cxx
    \author Romain BONDUE
    \date 04/01/2006 */
#include <iostream>
#include <sstream>
#include <memory>
#include <map>
#include <vector>
#include <exception>
#include <stdexcept>
#include <cerrno> // errno, ERANGE
#include <cstdlib> // strtol(), strtoul()
#include <cstring> // strncmp(), strcmp(), memset()

extern "C"{
#include <shadow.h> // ::getspnam()
#include <crypt.h> // ::crypt()
#include <unistd.h> // ::sleep(), ::read(), ::write(), ::setuid()
}

#include "MsgHandlerFactory.h"
#include "SystemTools.h"
#include "Msg.h"
#include "IMsgHandler.h"
#include "Exceptions.h"
#include "StaticSettings.h"
#include "ErrorsCode.h"



namespace
{
        /* Instance number -> remote message handler. The pointed handlers
         * are deleted by main() (on DeleteRemoteHandlerCmd and at exit). */
    typedef std::map<unsigned long, nsRoot::IMsgHandler*> pMsgHandlersMap_t;


    /** Local socket with helpers for the RutilT <-> Helper protocol : raw
     *  int / unsigned long values (sent as their in-memory bytes), strings
     *  prefixed by their size as an int, and messages (text then code). */
    class CFormattedSocket : public nsSystem::CLocalSocket
    {
      public :
        /** Reads one int into \a Buffer.
         *  \return true if Read() returned no data at all (treated as the end
         *          of the stream), false if a whole int was read.
         *  \throw nsErrors::CException (InvalidData) if only part of an int
         *         was read. */
        bool ReadIntOrStop (int& Buffer) throw (nsErrors::CException)
        {
            const unsigned NbByteRead (Read (reinterpret_cast<char*> (&Buffer),
                                       sizeof (Buffer)));
            if (NbByteRead)
            {
                if (NbByteRead != sizeof (Buffer))
                    throw nsErrors::CException ("Invalid data read, integer"
                                                " expected.",
                                                nsErrors::InvalidData);
                return false;
            }
            return true;

        } // ReadIntOrStop()


        /** Reads one int, the end of the stream being an error here.
         *  \throw nsErrors::CException (UnexpectedStreamEnd or InvalidData). */
        int ReadInt () throw (nsErrors::CException)
        {
            int Buffer;
            if (ReadIntOrStop (Buffer))
                throw nsErrors::CException ("Unexpected end of stream.",
                                            nsErrors::UnexpectedStreamEnd);
            return Buffer;

        } // ReadInt()


        /** Reads one unsigned long.
         *  \throw nsErrors::CException (InvalidData) on a short read. */
        unsigned long ReadULong () throw (nsErrors::CException)
        {
            unsigned long Buffer;
            if (Read (reinterpret_cast<char*> (&Buffer), sizeof (Buffer)) !=
                                                                sizeof (Buffer))
                throw nsErrors::CException ("Invalid data read, unsigned long"
                                            " expected.",
                                            nsErrors::InvalidData);
            return Buffer;

        } // ReadULong()


        /** Sends \a Value as its raw in-memory bytes. */
        void Write (int Value) throw (nsErrors::CSystemExc)
        {
            CLocalSocket::Write (reinterpret_cast<const char*> (&Value),
                                 sizeof (Value));

        } // Write()


        /** Reads exactly \a Size bytes and returns them as a string ; embedded
         *  '\0' bytes are kept.
         *  \throw nsErrors::CException (InvalidData) on a short read.
         *  \throw std::bad_alloc if the buffer can't be allocated. */
        std::string ReadString (unsigned Size) throw (nsErrors::CException,
                                                      std::bad_alloc)
        {
            if (!Size) return std::string();
                // std::vector releases its storage with delete [], matching
                // the array allocation (std::auto_ptr would use delete).
            std::vector<char> Buffer (Size);
            if (Read (&Buffer [0], Size) != Size)
                throw nsErrors::CException ("Invalid data read.",
                                            nsErrors::InvalidData);
                // (pointer, length) constructor : copies exactly Size bytes,
                // the buffer is not '\0'-terminated.
            return std::string (&Buffer [0], Size);

        } // ReadString()


        /** Sends \a Str prefixed by its size (as an int). */
        void Write (const std::string& Str) throw (nsErrors::CSystemExc)
        {
            Write (static_cast<int> (Str.size()));
            CLocalSocket::Write (Str.data(), Str.size());

        } // Write()


        /** Sends a message : its text (size-prefixed) then its code. */
        void Write (const nsRoot::CMsg& Msg) throw (nsErrors::CSystemExc)
        {
            Write (Msg.GetText());
            Write (Msg.GetCode());

        } // Write()

    }; // CFormattedSocket


    /** Parses a base-10 unsigned integer made only of digits.
     *  strtoul() alone would accept "" (as 0), leading blanks, a '-' sign
     *  (silently wrapping around) and would saturate on overflow : all of
     *  these are rejected here, as is anything after the digits (including
     *  embedded '\0').
     *  \throw nsErrors::CException (InvalidData) if \a Str isn't valid. */
    unsigned long ParseULong (const std::string& Str)
                                                throw (nsErrors::CException)
    {
        if (Str.empty() || Str [0] < '0' || Str [0] > '9')
            throw nsErrors::CException ("Invalid data read.",
                                        nsErrors::InvalidData);
        char* pEnd;
        errno = 0;
        const unsigned long Num (strtoul (Str.c_str(), &pEnd, 10));
        if (pEnd != Str.c_str() + Str.size() || errno == ERANGE)
            throw nsErrors::CException ("Invalid data read.",
                                        nsErrors::InvalidData);
        return Num;

    } // ParseULong()


    /** Parses a base-10 signed integer (strtol() syntax), which must fill
     *  the whole string and be in the range of long.
     *  \throw nsErrors::CException (InvalidData) if \a Str isn't valid. */
    long ParseLong (const std::string& Str) throw (nsErrors::CException)
    {
        char* pEnd;
        errno = 0;
        const long Num (strtol (Str.c_str(), &pEnd, 10));
        if (pEnd == Str.c_str()                     // No digit at all.
            || pEnd != Str.c_str() + Str.size()     // Trailing garbage.
            || errno == ERANGE)                     // Overflow.
            throw nsErrors::CException ("Invalid data read.",
                                        nsErrors::InvalidData);
        return Num;

    } // ParseLong()


    /** Checks \a Password against root's hash in the shadow database.
     *  On mismatch, sleeps 3 seconds before returning, to slow down
     *  repeated guesses.
     *  \return true if the password matches.
     *  \throw nsErrors::CSystemExc if root's shadow entry can't be read or
     *         crypt() fails. */
    bool IsRootPasswd (const std::string& Password)
                                                throw (nsErrors::CException)
    {
        const ::spwd* const pShadow (::getspnam ("root"));
        if (!pShadow || !pShadow->sp_pwdp)
            throw nsErrors::CSystemExc ("Can't check root password.");
            // crypt() accepts a complete stored hash as its "setting" and
            // takes only the prefix (method id, parameters, salt) it needs.
            // This avoids parsing the salt by hand, which breaks on hashes
            // with more than three '$' (e.g. "rounds=", yescrypt) and on
            // entries shorter than the expected salt.
        const char* const CryptedPassword (::crypt (Password.c_str(),
                                                    pShadow->sp_pwdp));
        if (!CryptedPassword)
            throw nsErrors::CSystemExc ("Cannot encrypt password.");
        if (!std::strcmp (CryptedPassword, pShadow->sp_pwdp))
            return true;
        ::sleep (3); // For security reasons.
        return false;

    } // IsRootPasswd()

} // anonymous namespace



int main (int argc, char* argv [])
{
    try
    {
        if (::geteuid())
            throw nsErrors::CException ("Must be executed as root.",
                                        nsErrors::HelperNotRoot);
        if (::setuid (0))
            std::cerr << "Helper : Warning, cannot set user id. errno : "
                      << errno << std::endl;
        if (argc != 2)
            throw nsErrors::CException ("Invalid number of arguments.",
                                        nsErrors::InvalidArguments);
        const ::pid_t RutilTPid (ParseLong (argv [1]));
        CFormattedSocket Sock;
        std::ostringstream OsAddr;
        OsAddr << nsRoot::ServerAddr << RutilTPid;
        Sock.Connect (OsAddr.str());
        Sock.SendCredential();
        if (!Sock.CheckCredentialPID (RutilTPid))
            throw nsErrors::CException ("RutilT authentication failed.",
                                        nsErrors::RutilTAuthenticationFailed);

#ifndef NOROOTPASSCHECK
            // Check root password :
        if (Sock.ReadInt() != nsRoot::CheckRootPassword)
            throw nsErrors::CException ("Invalid command, password check"
                                        " expected.", nsErrors::InvalidData);
        if (!IsRootPasswd (Sock.ReadString (Sock.ReadInt())))
        {
            Sock.Write (nsErrors::InvalidRootPassword);
            std::cerr << "Helper : exiting, invalid root password."
                      << std::endl;
            return nsErrors::InvalidRootPassword;
        }
        Sock.Write (0);
#endif // NOROOTPASSCHECK

        nsRoot::IMsgHandler* pCurrentMsgHandler (0);
        unsigned long InstanceCounter (Sock.ReadULong());
        pMsgHandlersMap_t MsgHandlersMap;
        for ( ; ; )
        {
            int Value;
#ifndef NDEBUG
            std::cerr << "Helper : waiting for command.\n";
#endif // NDEBUG
            if (Sock.ReadIntOrStop (Value))
                break;
#ifndef NDEBUG
            std::cerr << "Helper : command received : " << Value << std::endl;
#endif // NDEBUG
            if (Value < 0) // Internal command.
            {
                const unsigned Size (Sock.ReadInt());
                const std::string Data (Sock.ReadString (Size));
                switch (Value)
                {
                  case nsRoot::CreateRemoteHandlerCmd :
                    MsgHandlersMap [InstanceCounter] =
                                                   nsRoot::MakeHandler (Data);
                    ++InstanceCounter;
                  break;

                  case nsRoot::ChangeRemoteHandlerCmd :
                    pCurrentMsgHandler = MsgHandlersMap [ParseULong (Data)];
                    if (!pCurrentMsgHandler)
                        throw nsErrors::CException
                                              ("Invalid instance number.",
                                               nsErrors::InvalidInstanceNum);
                  break;

                  case nsRoot::DeleteRemoteHandlerCmd :
                  {
                    const pMsgHandlersMap_t::iterator Iter
                            (MsgHandlersMap.find (ParseULong (Data)));
                    if (Iter != MsgHandlersMap.end())
                    {
                        delete Iter->second;
                        MsgHandlersMap.erase (Iter);
                    }
                  }
                  break;

                  default :
                    throw nsErrors::CException ("Invalid command.",
                                                nsErrors::InvalidCommand);
                }
                Sock.Write (0); // Everything is ok.
            }
            else // CMsg
            {
                const std::string Data (Sock.ReadString (Value));
                if (pCurrentMsgHandler)
                    Sock.Write ((*pCurrentMsgHandler) (nsRoot::CMsg (Data,
                                                       Sock.ReadInt())));
                else throw nsErrors::CException ("No handler defined.",
                                                 nsErrors::NoHandlerDefined);
            }
        }
        for (std::map<unsigned long, nsRoot::IMsgHandler*>::iterator Iter
            (MsgHandlersMap.begin()) ; Iter != MsgHandlersMap.end() ; ++Iter)
            delete Iter->second;
    }
    catch (const std::bad_alloc& Exc)
    {
        std::cerr << "Helper : " << Exc.what() << std::endl;
        return nsErrors::OutOfMemory;
    }
    catch (const nsErrors::CException& Exc) // Handle CSystemExc too.
    {
        std::cerr << "Helper : " << Exc.GetMsg() << "\nCode : "
                  << Exc.GetCode() << std::endl;
        return Exc.GetCode();
    }
    catch (const std::exception& Exc)
    {
        std::cerr << "Helper : exception : " << Exc.what() << std::endl;
        return nsErrors::UnknownExc;
    }
    catch (...)
    {
        std::cerr << "Helper : unknown execption." << std::endl;
        return nsErrors::UnknownExc;
    }

} // main()
